Compare commits

..

6 Commits

Author SHA1 Message Date
bcmmbaga
4bed26e416 Add diff and hash tests for ignored tags 2024-07-02 13:29:46 +03:00
bcmmbaga
67cc8bd655 add tests 2024-07-02 12:57:42 +03:00
bcmmbaga
42be72a86c Replace hashstructure package with r3labs/diff for network map updates 2024-06-28 15:44:38 +03:00
bcmmbaga
16387a823a Reset timer in benchmark test functions 2024-06-27 17:14:22 +03:00
bcmmbaga
b4dddc8d0f Add server account peer update functions and tests 2024-06-27 00:55:48 +03:00
bcmmbaga
7a0dc10ccc Add network map hash to avoid unnecessary updates 2024-06-26 16:29:10 +03:00
149 changed files with 3516 additions and 3738 deletions

View File

@@ -173,7 +173,7 @@ jobs:
retention-days: 3
release_ui_darwin:
runs-on: macos-latest
runs-on: macos-11
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV

View File

@@ -178,79 +178,34 @@ jobs:
- name: Checkout code
uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL
- name: run script
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen postgres
- name: test Caddy file gen
run: test -f Caddyfile
- name: test docker-compose file gen postgres
- name: test docker-compose file gen
run: test -f docker-compose.yml
- name: test management.json file gen postgres
- name: test management.json file gen
run: test -f management.json
- name: test turnserver.conf file gen postgres
- name: test turnserver.conf file gen
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen postgres
- name: test zitadel.env file gen
run: test -f zitadel.env
- name: test dashboard.env file gen postgres
- name: test dashboard.env file gen
run: test -f dashboard.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker-compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
test-download-geolite2-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code
uses: actions/checkout@v3
- name: test script
run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists
run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
run: test -f geonames.db

View File

@@ -3,10 +3,8 @@ builds:
- id: netbird-ui-darwin
dir: client/ui
binary: netbird-ui
env:
- CGO_ENABLED=1
- MACOSX_DEPLOYMENT_TARGET=11.0
- MACOS_DEPLOYMENT_TARGET=11.0
env: [CGO_ENABLED=1]
goos:
- darwin
goarch:

View File

@@ -1,4 +1,4 @@
FROM alpine:3.19
FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
if err != nil {
t.Fatal(err)
}
@@ -87,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil
}
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil {
return err
}
@@ -101,7 +101,6 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
}
delete(i.rules, ruleKey)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
@@ -318,13 +317,6 @@ func (i *routerManager) createChain(table, newChain string) error {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
@@ -334,30 +326,6 @@ func (i *routerManager) createChain(table, newChain string) error {
return nil
}
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}

View File

@@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddRoutingRules(pair)
return m.router.InsertRoutingRules(pair)
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {

View File

@@ -22,8 +22,6 @@ const (
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
@@ -128,22 +126,6 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
@@ -156,28 +138,28 @@ func (r *router) createContainers() error {
return nil
}
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
@@ -195,8 +177,8 @@ func (r *router) AddRoutingRules(pair manager.RouterPair) error {
return nil
}
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -217,7 +199,7 @@ func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPai
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,

View File

@@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
err = manager.AddRoutingRules(testCase.InputPair)
err = manager.InsertRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()

View File

@@ -78,11 +78,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
select {
case <-u.ctx.Done():

View File

@@ -1465,15 +1465,6 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
}
func (e *Engine) restartEngine() {
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
@@ -1482,29 +1473,14 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New()
go func() {
var mu sync.Mutex
var debounceTimer *time.Timer
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
debounceTimer.Stop()
log.Infof("Network monitor detected network change, restarting engine")
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)

View File

@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
t.Fatal(err)
}
// time.Sleep(250 * time.Millisecond)
//time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
@@ -1057,7 +1057,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
if err != nil {
return nil, "", err
}
@@ -1068,13 +1068,13 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@@ -45,6 +45,24 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle interface state changes
case unix.RTM_IFINFO:
ifinfo, err := parseInterfaceMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing interface message: %v", err)
continue
}
if msg.Flags&unix.IFF_UP != 0 {
continue
}
if (nexthopv4.Intf == nil || ifinfo.Index != nexthopv4.Intf.Index) && (nexthopv6.Intf == nil || ifinfo.Index != nexthopv6.Intf.Index) {
continue
}
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
go callback()
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
@@ -76,6 +94,24 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
}
}
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.InterfaceMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return msg, nil
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {

View File

@@ -19,9 +19,14 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return errors.New("no interfaces available")
}
linkChan := make(chan netlink.LinkUpdate)
done := make(chan struct{})
defer close(done)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
return fmt.Errorf("subscribe to link updates: %v", err)
}
routeChan := make(chan netlink.RouteUpdate)
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
return fmt.Errorf("subscribe to route updates: %v", err)
@@ -33,6 +38,25 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case <-ctx.Done():
return ErrStopped
// handle interface state changes
case update := <-linkChan:
if (nexthopv4.Intf == nil || update.Index != int32(nexthopv4.Intf.Index)) && (nexthopv6.Intf == nil || update.Index != int32(nexthopv6.Intf.Index)) {
continue
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
go callback()
return nil
case syscall.RTM_NEWLINK:
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
go callback()
return nil
}
}
// handle route changes
case route := <-routeChan:
// default route and main table

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"strings"
"time"
log "github.com/sirupsen/logrus"
@@ -34,8 +33,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("get neighbors: %w", err)
}
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
if n, ok := initialNeighbors[nexthopv4.IP]; ok {
neighborv4 = &n
}
if n, ok := initialNeighbors[nexthopv6.IP]; ok {
neighborv6 = &n
}
}
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
@@ -55,16 +58,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
}
}
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
if n, ok := initialNeighbors[nexthop.IP]; ok &&
n.State != unreachable &&
n.State != incomplete &&
n.State != tbd {
return &n
}
return nil
}
func changed(
nexthopv4 systemops.Nexthop,
neighborv4 *systemops.Neighbor,
@@ -94,64 +87,37 @@ func changed(
}
// routeChanged checks if the default routes still point to our nexthop/interface
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes map[netip.Prefix]systemops.Route) bool {
if !nexthop.IP.IsValid() {
return false
}
unspec := getUnspecifiedPrefix(nexthop.IP)
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
var unspec netip.Prefix
if nexthop.IP.Is6() {
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
} else {
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
if !foundMatchingRoute {
logRouteChange(nexthop.IP, intf)
if r, ok := routes[unspec]; ok {
if r.Nexthop != nexthop.IP || compareIntf(r.Interface, intf) != 0 {
oldIntf, newIntf := "<nil>", "<nil>"
if intf != nil {
oldIntf = intf.Name
}
if r.Interface != nil {
newIntf = r.Interface.Name
}
log.Infof("network monitor: default route changed: %s from %s (%s) to %s (%s)", r.Destination, nexthop.IP, oldIntf, r.Nexthop, newIntf)
return true
}
} else {
log.Infof("network monitor: default route is gone")
return true
}
return false
}
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
if ip.Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string
foundMatchingRoute := false
for _, r := range routes {
if r.Destination == unspec {
routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo)
}
}
}
return defaultRoutes, foundMatchingRoute
}
func formatRouteInfo(r systemops.Route) string {
newIntf := "<nil>"
if r.Interface != nil {
newIntf = r.Interface.Name
}
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
}
func logRouteChange(ip netip.Addr, intf *net.Interface) {
oldIntf := "<nil>"
if intf != nil {
oldIntf = intf.Name
}
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
}
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
@@ -161,7 +127,7 @@ func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, ne
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
if n, ok := neighbors[nexthop.IP]; ok {
if n.State == unreachable || n.State == incomplete {
if n.State != reachable && n.State != permanent {
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
return true
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
@@ -199,13 +165,18 @@ func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
return neighbours, nil
}
func getRoutes() ([]systemops.Route, error) {
func getRoutes() (map[netip.Prefix]systemops.Route, error) {
entries, err := systemops.GetRoutes()
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
return entries, nil
routes := make(map[netip.Prefix]systemops.Route, len(entries))
for _, entry := range entries {
routes[entry.Destination] = entry
}
return routes, nil
}
func stateFromInt(state uint8) string {

View File

@@ -108,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
if err != nil {
return nil, "", err
}
@@ -119,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@@ -89,9 +89,5 @@ func _getInfo() string {
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var si sysinfo.SysInfo
si.GetSysInfo()
serial := si.Chassis.Serial
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
serial = si.Product.Serial
}
return serial, si.Product.Name, si.Product.Vendor
return si.Chassis.Serial, si.Product.Name, si.Product.Vendor
}

View File

@@ -80,7 +80,6 @@ func main() {
log.Errorf("check PID file: %v", err)
return
}
client.setDefaultFonts()
systray.Run(client.onTrayReady, client.onTrayExit)
}
}
@@ -877,88 +876,3 @@ func checkPIDFile() error {
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
}
func (s *serviceClient) setDefaultFonts() {
var (
defaultFontPath string
)
//TODO: Linux Multiple Language Support
switch runtime.GOOS {
case "darwin":
defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
case "windows":
fontPath := s.getWindowsFontFilePath()
defaultFontPath = fontPath
}
_, err := os.Stat(defaultFontPath)
if err == nil {
os.Setenv("FYNE_FONT", defaultFontPath)
}
}
func (s *serviceClient) getWindowsFontFilePath() (fontPath string) {
/*
https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts
https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list
*/
var (
fontFolder string = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name")
output, err := cmd.Output()
if err != nil {
log.Errorf("Failed to get Windows default language setting: %v", err)
fontPath = path.Join(fontFolder, fontMapping["default"])
return
}
defaultLanguage := strings.TrimSpace(string(output))
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
fontPath = path.Join(fontFolder, fontMapping["nirmala"])
return
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
fontPath = path.Join(fontFolder, font)
} else {
fontPath = path.Join(fontFolder, fontMapping["default"])
}
return
}

View File

@@ -4,7 +4,6 @@ package main
import (
"fmt"
"sort"
"strings"
"time"
@@ -18,57 +17,28 @@ import (
"github.com/netbirdio/netbird/client/proto"
)
const (
allRoutesText = "All routes"
overlappingRoutesText = "Overlapping routes"
exitNodeRoutesText = "Exit-node routes"
allRoutes filter = "all"
overlappingRoutes filter = "overlapping"
exitNodeRoutes filter = "exit-node"
getClientFMT = "get client: %v"
)
type filter string
func (s *serviceClient) showRoutesUI() {
s.wRoutes = s.app.NewWindow("NetBird Routes")
allGrid := container.New(layout.NewGridLayout(3))
go s.updateRoutes(allGrid, allRoutes)
overlappingGrid := container.New(layout.NewGridLayout(3))
exitNodeGrid := container.New(layout.NewGridLayout(3))
grid := container.New(layout.NewGridLayout(3))
go s.updateRoutes(grid)
routeCheckContainer := container.NewVBox()
tabs := container.NewAppTabs(
container.NewTabItem(allRoutesText, allGrid),
container.NewTabItem(overlappingRoutesText, overlappingGrid),
container.NewTabItem(exitNodeRoutesText, exitNodeGrid),
)
tabs.OnSelected = func(item *container.TabItem) {
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}
tabs.OnUnselected = func(item *container.TabItem) {
grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
grid.Objects = nil
}
routeCheckContainer.Add(tabs)
routeCheckContainer.Add(grid)
scrollContainer := container.NewVScroll(routeCheckContainer)
scrollContainer.SetMinSize(fyne.NewSize(200, 300))
buttonBox := container.NewHBox(
layout.NewSpacer(),
widget.NewButton("Refresh", func() {
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.updateRoutes(grid)
}),
widget.NewButton("Select all", func() {
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.selectAllFilteredRoutes(f)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.selectAllRoutes()
s.updateRoutes(grid)
}),
widget.NewButton("Deselect All", func() {
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.deselectAllFilteredRoutes(f)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.deselectAllRoutes()
s.updateRoutes(grid)
}),
layout.NewSpacer(),
)
@@ -78,12 +48,18 @@ func (s *serviceClient) showRoutesUI() {
s.wRoutes.SetContent(content)
s.wRoutes.Show()
s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid)
s.startAutoRefresh(5*time.Second, grid)
}
func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) {
func (s *serviceClient) updateRoutes(grid *fyne.Container) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
grid.Objects = nil
grid.Refresh()
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
@@ -91,15 +67,7 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) {
grid.Add(idHeader)
grid.Add(networkHeader)
grid.Add(resolvedIPsHeader)
filteredRoutes, err := s.getFilteredRoutes(f)
if err != nil {
return
}
sortRoutesByIDs(filteredRoutes)
for _, route := range filteredRoutes {
for _, route := range routes {
r := route
checkBox := widget.NewCheck(r.GetID(), func(checked bool) {
@@ -112,104 +80,35 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) {
grid.Add(checkBox)
network := r.GetNetwork()
domains := r.GetDomains()
if len(domains) == 0 {
grid.Add(widget.NewLabel(network))
grid.Add(widget.NewLabel(""))
continue
if len(domains) > 0 {
network = strings.Join(domains, ", ")
}
grid.Add(widget.NewLabel(network))
// our selectors are only for display
noopFunc := func(_ string) {
// do nothing
}
domainsSelector := widget.NewSelect(domains, noopFunc)
domainsSelector.Selected = domains[0]
grid.Add(domainsSelector)
var resolvedIPsList []string
for _, domain := range domains {
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
if len(domains) > 0 {
var resolvedIPsList []string
for _, domain := range r.GetDomains() {
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
}
}
}
if len(resolvedIPsList) == 0 {
// TODO: limit width
resolvedIPsLabel := widget.NewLabel(strings.Join(resolvedIPsList, ", "))
grid.Add(resolvedIPsLabel)
} else {
grid.Add(widget.NewLabel(""))
continue
}
// TODO: limit width within the selector display
resolvedIPsSelector := widget.NewSelect(resolvedIPsList, noopFunc)
resolvedIPsSelector.Selected = resolvedIPsList[0]
resolvedIPsSelector.Resize(fyne.NewSize(100, 100))
grid.Add(resolvedIPsSelector)
}
}
s.wRoutes.Content().Refresh()
grid.Refresh()
}
func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf(getClientFMT, err)
s.showError(fmt.Errorf(getClientFMT, err))
return nil, err
}
switch f {
case overlappingRoutes:
return getOverlappingRoutes(routes), nil
case exitNodeRoutes:
return getExitNodeRoutes(routes), nil
default:
}
return routes, nil
}
func getOverlappingRoutes(routes []*proto.Route) []*proto.Route {
var filteredRoutes []*proto.Route
existingRange := make(map[string][]*proto.Route)
for _, route := range routes {
if len(route.Domains) > 0 {
continue
}
if r, exists := existingRange[route.GetNetwork()]; exists {
r = append(r, route)
existingRange[route.GetNetwork()] = r
} else {
existingRange[route.GetNetwork()] = []*proto.Route{route}
}
}
for _, r := range existingRange {
if len(r) > 1 {
filteredRoutes = append(filteredRoutes, r...)
}
}
return filteredRoutes
}
func getExitNodeRoutes(routes []*proto.Route) []*proto.Route {
var filteredRoutes []*proto.Route
for _, route := range routes {
if route.Network == "0.0.0.0/0" {
filteredRoutes = append(filteredRoutes, route)
}
}
return filteredRoutes
}
func sortRoutesByIDs(routes []*proto.Route) {
sort.Slice(routes, func(i, j int) bool {
return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID())
})
}
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return nil, fmt.Errorf(getClientFMT, err)
return nil, fmt.Errorf("get client: %v", err)
}
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
@@ -223,8 +122,8 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
func (s *serviceClient) selectRoute(id string, checked bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf(getClientFMT, err)
s.showError(fmt.Errorf(getClientFMT, err))
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
@@ -250,14 +149,16 @@ func (s *serviceClient) selectRoute(id string, checked bool) {
}
}
func (s *serviceClient) selectAllFilteredRoutes(f filter) {
func (s *serviceClient) selectAllRoutes() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf(getClientFMT, err)
log.Errorf("get client: %v", err)
return
}
req := s.getRoutesRequest(f, true)
req := &proto.SelectRoutesRequest{
All: true,
}
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select all routes: %v", err)
s.showError(fmt.Errorf("failed to select all routes: %v", err))
@@ -267,14 +168,16 @@ func (s *serviceClient) selectAllFilteredRoutes(f filter) {
log.Debug("All routes selected")
}
func (s *serviceClient) deselectAllFilteredRoutes(f filter) {
func (s *serviceClient) deselectAllRoutes() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf(getClientFMT, err)
log.Errorf("get client: %v", err)
return
}
req := s.getRoutesRequest(f, false)
req := &proto.SelectRoutesRequest{
All: true,
}
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect all routes: %v", err)
s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
@@ -284,34 +187,17 @@ func (s *serviceClient) deselectAllFilteredRoutes(f filter) {
log.Debug("All routes deselected")
}
func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest {
req := &proto.SelectRoutesRequest{}
if f == allRoutes {
req.All = true
} else {
routes, err := s.getFilteredRoutes(f)
if err != nil {
return nil
}
for _, route := range routes {
req.RouteIDs = append(req.RouteIDs, route.GetID())
}
req.Append = appendRoute
}
return req
}
func (s *serviceClient) showError(err error) {
wrappedMessage := wrapText(err.Error(), 50)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
}
func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
s.updateRoutes(grid)
}
}()
@@ -320,23 +206,6 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container
})
}
func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
s.wRoutes.Content().Refresh()
s.updateRoutes(grid, f)
}
func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) {
switch tabs.Selected().Text {
case overlappingRoutesText:
return overlappingGrid, overlappingRoutes
case exitNodeRoutesText:
return exitNodesGrid, exitNodeRoutes
default:
return allGrid, allRoutes
}
}
// wrapText inserts newlines into the text to ensure that each line is
// no longer than 'lineLength' runes.
func wrapText(text string, lineLength int) string {

View File

@@ -7,18 +7,6 @@ import (
"strings"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/context"
)
type ExecutionContext string
const (
ExecutionContextKey = "executionContext"
HTTPSource ExecutionContext = "HTTP"
GRPCSource ExecutionContext = "GRPC"
SystemSource ExecutionContext = "SYSTEM"
)
// ContextHook is a custom hook for add the source information for the entry
@@ -42,27 +30,6 @@ func (hook ContextHook) Levels() []logrus.Level {
func (hook ContextHook) Fire(entry *logrus.Entry) error {
src := hook.parseSrc(entry.Caller.File)
entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
if entry.Context == nil {
return nil
}
source, ok := entry.Context.Value(ExecutionContextKey).(ExecutionContext)
if !ok {
return nil
}
entry.Data["context"] = source
switch source {
case HTTPSource:
addHTTPFields(entry)
case GRPCSource:
addGRPCFields(entry)
case SystemSource:
addSystemFields(entry)
}
return nil
}
@@ -92,42 +59,3 @@ func (hook ContextHook) parseSrc(filePath string) string {
file := path.Base(filePath)
return fmt.Sprintf("%s/%s", pkg, file)
}
func addHTTPFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
}
func addGRPCFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}
func addSystemFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}

View File

@@ -1,8 +1,6 @@
package formatter
import (
"github.com/sirupsen/logrus"
)
import "github.com/sirupsen/logrus"
// SetTextFormatter set the text formatter for given logger.
func SetTextFormatter(logger *logrus.Logger) {
@@ -11,13 +9,6 @@ func SetTextFormatter(logger *logrus.Logger) {
logger.AddHook(NewContextHook())
}
// SetJSONFormatter set the JSON formatter for given logger.
func SetJSONFormatter(logger *logrus.Logger) {
logger.Formatter = &logrus.JSONFormatter{}
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
// SetLogcatFormatter set the logcat formatter for given logger.
func SetLogcatFormatter(logger *logrus.Logger) {
logger.Formatter = NewLogcatFormatter()

8
go.mod
View File

@@ -44,6 +44,7 @@ require (
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19
github.com/google/martian/v3 v3.0.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
@@ -57,7 +58,7 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -66,6 +67,7 @@ require (
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1
github.com/r3labs/diff v1.1.0
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
@@ -188,8 +190,8 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/image v0.10.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect

18
go.sum
View File

@@ -209,6 +209,8 @@ github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
@@ -333,8 +335,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd h1:IzGGIJMpz07aPs3R6/4sxZv63JoCMddftLpVodUK+Ec=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
@@ -413,6 +415,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/r3labs/diff v1.1.0 h1:V53xhrbTHrWFWq3gI4b94AjgEJOerO1+1l0xyHOBi8M=
github.com/r3labs/diff v1.1.0/go.mod h1:7WjXasNzi0vJetRcB/RqNl5dlIsmXcTTLmF5IoH6Xig=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
@@ -540,8 +544,8 @@ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJ
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/image v0.10.0 h1:gXjUUtwtx5yOE0VKWq1CH4IJAClq4UGgUA3i+rpON9M=
golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
@@ -563,6 +567,7 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -652,10 +657,11 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -28,11 +28,7 @@ services:
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
volumes:
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Signal
signal:
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
@@ -44,11 +40,6 @@ services:
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Management
management:
@@ -72,16 +63,12 @@ services:
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Coturn
coturn:
image: coturn/coturn:$COTURN_TAG
restart: unless-stopped
#domainname: $TURN_DOMAIN # only needed when TLS is enabled
domainname: $TURN_DOMAIN
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
@@ -89,11 +76,7 @@ services:
network_mode: host
command:
- -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
$MGMT_VOLUMENAME:
$SIGNAL_VOLUMENAME:

View File

@@ -50,7 +50,7 @@ check_jq() {
wait_crdb() {
set +e
while true; do
if $DOCKER_COMPOSE_COMMAND exec -T zdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then
if $DOCKER_COMPOSE_COMMAND exec -T crdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then
break
fi
echo -n " ."
@@ -61,16 +61,14 @@ wait_crdb() {
}
init_crdb() {
if [[ $ZITADEL_DATABASE == "cockroach" ]]; then
echo -e "\nInitializing Zitadel's CockroachDB\n\n"
$DOCKER_COMPOSE_COMMAND up -d zdb
echo ""
# shellcheck disable=SC2028
echo -n "Waiting CockroachDB to become ready"
wait_crdb
$DOCKER_COMPOSE_COMMAND exec -T zdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/"
handle_request_command_status $? "init_crdb failed" ""
fi
echo -e "\nInitializing Zitadel's CockroachDB\n\n"
$DOCKER_COMPOSE_COMMAND up -d crdb
echo ""
# shellcheck disable=SC2028
echo -n "Waiting cockroachDB to become ready "
wait_crdb
$DOCKER_COMPOSE_COMMAND exec -T crdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/"
handle_request_command_status $? "init_crdb failed" ""
}
get_main_ip_address() {
@@ -158,7 +156,7 @@ create_new_application() {
"'"$BASE_REDIRECT_URL2"'"
],
"postLogoutRedirectUris": [
"'"$LOGOUT_URL"'"
"'"$LOGOUT_URL"'"
],
"RESPONSETypes": [
"OIDC_RESPONSE_TYPE_CODE"
@@ -463,20 +461,6 @@ initEnvironment() {
exit 1
fi
if [[ $ZITADEL_DATABASE == "cockroach" ]]; then
echo "Use CockroachDB as Zitadel database."
ZDB=$(renderDockerComposeCockroachDB)
ZITADEL_DB_ENV=$(renderZitadelCockroachDBEnv)
else
echo "Use Postgres as default Zitadel database."
echo "For using CockroachDB please the environment variable 'export ZITADEL_DATABASE=cockroach'."
POSTGRES_ROOT_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@"
POSTGRES_ZITADEL_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@"
ZDB=$(renderDockerComposePostgres)
ZITADEL_DB_ENV=$(renderZitadelPostgresEnv)
renderPostgresEnv > zdb.env
fi
echo Rendering initial files...
renderDockerCompose > docker-compose.yml
renderCaddyfile > Caddyfile
@@ -490,7 +474,7 @@ initEnvironment() {
init_crdb
echo -e "\nStarting Zitadel IDP for user management\n\n"
echo -e "\nStarting Zidatel IDP for user management\n\n"
$DOCKER_COMPOSE_COMMAND up -d caddy zitadel
init_zitadel
@@ -650,15 +634,15 @@ renderManagementJson() {
"ExtraConfig": {
"ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1"
}
},
"DeviceAuthorizationFlow": {
"Provider": "hosted",
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"Scope": "openid"
}
},
},
"DeviceAuthorizationFlow": {
"Provider": "hosted",
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"Scope": "openid"
}
},
"PKCEAuthorizationFlow": {
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
@@ -695,6 +679,16 @@ renderZitadelEnv() {
cat <<EOF
ZITADEL_LOG_LEVEL=debug
ZITADEL_MASTERKEY=$ZITADEL_MASTERKEY
ZITADEL_DATABASE_COCKROACH_HOST=crdb
ZITADEL_DATABASE_COCKROACH_USER_USERNAME=zitadel_user
ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT="/crdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT="/crdb-certs/client.zitadel_user.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY="/crdb-certs/client.zitadel_user.key"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_ROOTCERT="/crdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_CERT="/crdb-certs/client.root.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_KEY="/crdb-certs/client.root.key"
ZITADEL_EXTERNALSECURE=$ZITADEL_EXTERNALSECURE
ZITADEL_TLS_ENABLED="false"
ZITADEL_EXTERNALPORT=$NETBIRD_PORT
@@ -704,43 +698,6 @@ ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_USERNAME=zitadel-admin-sa
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_NAME=Admin
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_SCOPES=openid
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_EXPIRATIONDATE=$ZIDATE_TOKEN_EXPIRATION_DATE
$ZITADEL_DB_ENV
EOF
}
renderZitadelCockroachDBEnv() {
cat <<EOF
ZITADEL_DATABASE_COCKROACH_HOST=zdb
ZITADEL_DATABASE_COCKROACH_USER_USERNAME=zitadel_user
ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT="/zdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT="/zdb-certs/client.zitadel_user.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY="/zdb-certs/client.zitadel_user.key"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_ROOTCERT="/zdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_CERT="/zdb-certs/client.root.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_KEY="/zdb-certs/client.root.key"
EOF
}
renderZitadelPostgresEnv() {
cat <<EOF
ZITADEL_DATABASE_POSTGRES_HOST=zdb
ZITADEL_DATABASE_POSTGRES_PORT=5432
ZITADEL_DATABASE_POSTGRES_DATABASE=zitadel
ZITADEL_DATABASE_POSTGRES_USER_USERNAME=zitadel
ZITADEL_DATABASE_POSTGRES_USER_PASSWORD=$POSTGRES_ZITADEL_PASSWORD
ZITADEL_DATABASE_POSTGRES_USER_SSL_MODE=disable
ZITADEL_DATABASE_POSTGRES_ADMIN_USERNAME=root
ZITADEL_DATABASE_POSTGRES_ADMIN_PASSWORD=$POSTGRES_ROOT_PASSWORD
ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_MODE=disable
EOF
}
renderPostgresEnv() {
cat <<EOF
POSTGRES_USER=root
POSTGRES_PASSWORD=$POSTGRES_ROOT_PASSWORD
EOF
}
@@ -767,21 +724,11 @@ services:
networks: [netbird]
env_file:
- ./dashboard.env
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Signal
signal:
image: netbirdio/signal:latest
restart: unless-stopped
networks: [netbird]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Management
management:
image: netbirdio/management:latest
@@ -799,26 +746,16 @@ services:
"--dns-domain=netbird.selfhosted",
"--idp-sign-key-refresh-enabled",
]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Coturn, AKA relay server
coturn:
image: coturn/coturn
restart: unless-stopped
#domainname: netbird.relay.selfhosted
domainname: netbird.relay.selfhosted
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
network_mode: host
command:
- -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Zitadel - identity provider
zitadel:
restart: 'always'
@@ -828,38 +765,20 @@ services:
env_file:
- ./zitadel.env
depends_on:
zdb:
crdb:
condition: 'service_healthy'
volumes:
- ./machinekey:/machinekey
- netbird_zitadel_certs:/zdb-certs:ro
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
$ZDB
netbird_zdb_data:
netbird_management:
netbird_caddy_data:
netbird_zitadel_certs:
networks:
netbird:
EOF
}
renderDockerComposeCockroachDB() {
cat <<EOF
# CockroachDB for Zitadel
zdb:
- netbird_zitadel_certs:/crdb-certs:ro
# CockroachDB for zitadel
crdb:
restart: 'always'
networks: [netbird]
image: 'cockroachdb/cockroach:latest-v23.2'
command: 'start-single-node --advertise-addr zdb'
command: 'start-single-node --advertise-addr crdb'
volumes:
- netbird_zdb_data:/cockroach/cockroach-data
- netbird_zdb_certs:/cockroach/certs
- netbird_crdb_data:/cockroach/cockroach-data
- netbird_crdb_certs:/cockroach/certs
- netbird_zitadel_certs:/zitadel-certs
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:8080/health?ready=1" ]
@@ -867,40 +786,16 @@ renderDockerComposeCockroachDB() {
timeout: '30s'
retries: 5
start_period: '20s'
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
netbird_zdb_certs:
EOF
}
netbird_management:
netbird_caddy_data:
netbird_crdb_data:
netbird_crdb_certs:
netbird_zitadel_certs:
renderDockerComposePostgres() {
cat <<EOF
# Postgres for Zitadel
zdb:
restart: 'always'
networks: [netbird]
image: 'postgres:16-alpine'
env_file:
- ./zdb.env
volumes:
- netbird_zdb_data:/var/lib/postgresql/data:rw
healthcheck:
test: ["CMD-SHELL", "pg_isready", "-d", "db_prod"]
interval: 5s
timeout: 60s
retries: 10
start_period: 5s
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
networks:
netbird:
EOF
}

View File

@@ -62,7 +62,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
if err != nil {
t.Fatal(err)
}
@@ -70,13 +70,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -20,7 +20,6 @@ import (
"time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -36,10 +35,8 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
@@ -80,10 +77,6 @@ var (
Short: "start NetBird Management Server",
PreRunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
//nolint
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
err := util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
@@ -92,7 +85,7 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, mgmtConfig)
config, err = loadMgmtConfig(mgmtConfig)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
}
@@ -123,11 +116,6 @@ var (
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
//nolint
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.SystemSource)
err := handleRebrand(cmd)
if err != nil {
return fmt.Errorf("failed to migrate files %v", err)
@@ -143,11 +131,11 @@ var (
if err != nil {
return err
}
err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics")
err = appMetrics.Expose(mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
store, err := server.NewStore(config.StoreConfig.Engine, config.Datadir, appMetrics)
if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
}
@@ -155,7 +143,7 @@ var (
var idpManager idp.Manager
if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics)
idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics)
if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
}
@@ -164,32 +152,32 @@ var (
if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
eventStore, key, err := integrations.InitEventStore(config.Datadir, config.DataStoreEncryptionKey)
if err != nil {
return fmt.Errorf("failed to initialize database: %s", err)
}
if config.DataStoreEncryptionKey != key {
log.WithContext(ctx).Infof("update config with activity store key")
log.Infof("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(ctx, mgmtConfig, config)
err := updateMgmtConfig(mgmtConfig, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
}
}
geo, err := geolocation.NewGeolocation(ctx, config.Datadir)
geo, err := geolocation.NewGeolocation(config.Datadir)
if err != nil {
log.WithContext(ctx).Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
} else {
log.WithContext(ctx).Infof("geo location service has been initialized from %s", config.Datadir)
log.Infof("geo location service has been initialized from %s", config.Datadir)
}
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore)
if err != nil {
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
}
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
@@ -200,13 +188,13 @@ var (
trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
log.Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
log.Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
}
realipOpts := []realip.Option{
@@ -218,8 +206,8 @@ var (
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...)),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...)),
}
var certManager *autocert.Manager
@@ -236,7 +224,7 @@ var (
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
if err != nil {
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
log.Errorf("cannot load TLS credentials: %v", err)
return err
}
transportCredentials := credentials.NewTLS(tlsConfig)
@@ -245,7 +233,6 @@ var (
}
jwtValidator, err := jwtclaims.NewJWTValidator(
ctx,
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
@@ -262,24 +249,26 @@ var (
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
ephemeralManager := server.NewEphemeralManager(store, accountManager)
ephemeralManager.LoadInitialPeers(ctx)
ephemeralManager.LoadInitialPeers()
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(ctx, store)
installationID, err := getInstallationID(store)
if err != nil {
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
log.Errorf("cannot load TLS credentials: %v", err)
return err
}
@@ -289,18 +278,18 @@ var (
idpManager = config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
go metricsWorker.Run(ctx)
go metricsWorker.Run()
}
var compatListener net.Listener
if mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort)
compatListener, err = serveGRPC(gRPCAPIHandler, ManagementLegacyPort)
if err != nil {
return err
}
log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler)
@@ -317,8 +306,8 @@ var (
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
}
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(ctx, cml, certManager.HTTPHandler(nil))
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(cml, certManager.HTTPHandler(nil))
}
} else if tlsConfig != nil {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig)
@@ -332,14 +321,14 @@ var (
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
log.Infof("management server version %s", version.NetbirdVersion())
log.Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(listener, rootHandler, tlsEnabled)
SetupCloseHandler()
<-stopCh
integratedPeerValidator.Stop(ctx)
integratedPeerValidator.Stop()
if geo != nil {
_ = geo.Stop()
}
@@ -350,68 +339,39 @@ var (
_ = certManager.Listener().Close()
}
gRPCAPIHandler.Stop()
_ = store.Close(ctx)
_ = eventStore.Close(ctx)
log.WithContext(ctx).Infof("stopped Management Service")
_ = store.Close()
_ = eventStore.Close()
log.Infof("stopped Management Service")
return nil
},
}
)
func unaryInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), formatter.ExecutionContextKey, formatter.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}
func notifyStop(ctx context.Context, msg string) {
func notifyStop(msg string) {
select {
case stopCh <- 1:
log.WithContext(ctx).Error(msg)
log.Error(msg)
default:
// stop has been already called, nothing to report
}
}
func getInstallationID(ctx context.Context, store server.Store) (string, error) {
func getInstallationID(store server.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(ctx, installationID)
err := store.SaveInstallationID(installationID)
if err != nil {
return "", err
}
return installationID, nil
}
func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
@@ -419,22 +379,22 @@ func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.List
go func() {
err := grpcServer.Serve(listener)
if err != nil {
notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
}
}()
return listener, nil
}
func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
func serveHTTP(httpListener net.Listener, handler http.Handler) {
go func() {
err := http.Serve(httpListener, handler)
if err != nil {
notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err))
notifyStop(fmt.Sprintf("failed running HTTP server: %v", err))
}
}()
}
func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled bool) {
go func() {
var err error
if tlsEnabled {
@@ -451,7 +411,7 @@ func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.
if err != nil {
select {
case stopCh <- 1:
log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err)
log.Errorf("failed to serve HTTP and gRPC server: %v", err)
default:
// stop has been already called, nothing to report
}
@@ -471,7 +431,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
})
}
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{}
_, err := util.ReadJson(mgmtConfigPath, loadedConfig)
if err != nil {
@@ -492,26 +452,26 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint != "" {
// if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically
log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
oidcConfig, err := fetchOIDCConfig(oidcEndpoint)
if err != nil {
return nil, err
}
log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer)
loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer
log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
@@ -519,7 +479,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
if err != nil {
return nil, err
}
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
@@ -529,10 +489,10 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
}
if loadedConfig.PKCEAuthorizationFlow != nil {
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
@@ -541,8 +501,8 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
return loadedConfig, err
}
func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error {
return util.DirectWriteJson(ctx, path, config)
func updateMgmtConfig(path string, config *server.Config) error {
return util.DirectWriteJson(path, config)
}
// OIDCConfigResponse used for parsing OIDC config response
@@ -555,7 +515,7 @@ type OIDCConfigResponse struct {
}
// fetchOIDCConfig fetches OIDC configuration from the IDP
func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigResponse, error) {
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
res, err := http.Get(oidcEndpoint)
if err != nil {
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
@@ -564,7 +524,7 @@ func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigRespon
defer func() {
err := res.Body.Close()
if err != nil {
log.WithContext(ctx).Debugf("failed closing response body %v", err)
log.Debugf("failed closing response body %v", err)
}
}()

View File

@@ -1,16 +1,13 @@
package cmd
import (
"context"
"flag"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command."
@@ -29,13 +26,10 @@ var upCmd = &cobra.Command{
return fmt.Errorf("failed initializing log %v", err)
}
//nolint
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil {
if err := server.MigrateFileStoreToSqlite(mgmtDataDir); err != nil {
return err
}
log.WithContext(ctx).Info("Migration finished successfully")
log.Info("Migration finished successfully")
return nil
},

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"encoding/json"
@@ -30,11 +29,11 @@ import (
type MocIntegratedValidator struct {
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
@@ -45,15 +44,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
return nil
}
@@ -61,7 +60,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
}
func (MocIntegratedValidator) Stop(_ context.Context) {
func (MocIntegratedValidator) Stop() {
}
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
@@ -86,7 +85,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac
setupKey = key.Key
}
_, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer)
_, _, _, err := manager.AddPeer(setupKey, userID, peer)
if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err)
}
@@ -396,7 +395,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
for _, testCase := range tt {
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
account := newAccountWithId("account-1", userID, "netbird.io")
account.UpdateSettings(&testCase.accountSettings)
account.Network = network
account.Peers = testCase.peers
@@ -410,7 +409,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
validatedPeers[p] = struct{}{}
}
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -420,7 +419,7 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io"
userId := "account_creator"
accountID := "account_id"
account := newAccountWithId(context.Background(), accountID, userId, domain)
account := newAccountWithId(accountID, userId, domain)
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
}
@@ -431,7 +430,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
account, err := manager.GetOrCreateAccountByUser(userID, "")
if err != nil {
t.Fatal(err)
}
@@ -440,7 +439,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return
}
account, err = manager.Store.GetAccountByUser(context.Background(), userID)
account, err = manager.Store.GetAccountByUser(userID)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
return
@@ -631,11 +630,11 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
}
@@ -643,7 +642,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -662,12 +661,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"
initAccount := newAccountWithId(context.Background(), "", userId, domain)
initAccount := newAccountWithId("", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
@@ -683,18 +682,18 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}
t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(context.Background(), initAccount)
err := manager.Store.SaveAccount(initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
@@ -702,11 +701,11 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(context.Background(), initAccount)
err := manager.Store.SaveAccount(initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 3, "groups should be added to the account")
@@ -729,7 +728,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
account := newAccountWithId("account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
@@ -743,7 +742,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
},
},
}
err := store.SaveAccount(context.Background(), account)
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@@ -752,7 +751,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
Store: store,
}
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
account, user, pat, err := am.GetAccountFromPAT(token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}
@@ -764,7 +763,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
account := newAccountWithId("account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
@@ -779,7 +778,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
},
},
}
err := store.SaveAccount(context.Background(), account)
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@@ -788,12 +787,12 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
Store: store,
}
err = am.MarkPATUsed(context.Background(), "tokenId")
err = am.MarkPATUsed("tokenId")
if err != nil {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = am.Store.GetAccount(context.Background(), "account_id")
account, err = am.Store.GetAccount("account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
@@ -808,7 +807,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
}
userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
account, err := manager.GetOrCreateAccountByUser(userId, "")
if err != nil {
t.Fatal(err)
}
@@ -816,7 +815,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId)
}
account, err = manager.Store.GetAccountByUser(context.Background(), userId)
account, err = manager.Store.GetAccountByUser(userId)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
}
@@ -835,7 +834,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
userId := "test_user"
domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
account, err := manager.GetOrCreateAccountByUser(userId, domain)
if err != nil {
t.Fatal(err)
}
@@ -849,7 +848,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
account, err = manager.GetOrCreateAccountByUser(userId, domain)
if err != nil {
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
@@ -872,7 +871,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
if err != nil {
t.Fatal(err)
}
@@ -881,20 +880,20 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return
}
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
_, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
}
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
_, err = manager.GetAccountByUserOrAccountID("", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
}
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
account := newAccountWithId(context.Background(), accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
account := newAccountWithId(accountID, userID, domain)
err := am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
@@ -916,7 +915,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
}
// AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
getAccount, err := manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@@ -953,12 +952,12 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(err)
}
err = manager.DeleteAccount(context.Background(), account.Id, userId)
err = manager.DeleteAccount(account.Id, userId)
if err != nil {
t.Fatal(err)
}
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
getAccount, err := manager.Store.GetAccount(account.Id)
if err == nil {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
}
@@ -979,7 +978,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
serial := account.Network.CurrentSerial() // should be 0
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@@ -998,7 +997,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedSetupKey := setupKey.Key
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@@ -1007,7 +1006,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return
}
account, err = manager.Store.GetAccount(context.Background(), account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@@ -1046,7 +1045,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.cloud")
if err != nil {
t.Fatal(err)
}
@@ -1066,7 +1065,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@@ -1075,7 +1074,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err = manager.Store.GetAccount(context.Background(), account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@@ -1122,7 +1121,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@@ -1141,7 +1140,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@@ -1157,14 +1156,14 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
peer2 := getPeer()
peer3 := getPeer()
account, err = manager.Store.GetAccount(context.Background(), account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID)
defer manager.peersUpdateManager.CloseChannel(peer1.ID)
group := group.Group{
ID: "group-id",
@@ -1198,7 +1197,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
if err := manager.SaveGroup(account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1218,7 +1217,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
if err := manager.DeletePolicy(account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@@ -1238,7 +1237,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
if err := manager.SavePolicy(account.Id, userID, &policy); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@@ -1257,7 +1256,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil {
t.Errorf("delete peer: %v", err)
return
}
@@ -1278,9 +1277,9 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}()
// clean policy is pre requirement for delete group
_ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
_ = manager.DeletePolicy(account.Id, policy.ID, userID)
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil {
t.Errorf("delete group: %v", err)
return
}
@@ -1302,7 +1301,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@@ -1316,7 +1315,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
})
@@ -1325,12 +1324,12 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return
}
err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID)
err = manager.DeletePeer(account.Id, peerKey, userID)
if err != nil {
return
}
account, err = manager.Store.GetAccount(context.Background(), account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@@ -1358,7 +1357,7 @@ func getEvent(t *testing.T, accountID string, manager AccountManager, eventType
case <-time.After(time.Second):
t.Fatal("no PeerAddedWithSetupKey event was generated")
default:
events, err := manager.GetEvents(context.Background(), accountID, userID)
events, err := manager.GetEvents(accountID, userID)
if err != nil {
t.Fatal(err)
}
@@ -1390,7 +1389,7 @@ func TestGetUsersFromAccount(t *testing.T) {
account.Users[user.Id] = user
}
userInfos, err := manager.GetUsersFromAccount(context.Background(), accountId, "1")
userInfos, err := manager.GetUsersFromAccount(accountId, "1")
if err != nil {
t.Fatal(err)
}
@@ -1501,7 +1500,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
},
}
routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
@@ -1511,7 +1510,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
assert.Len(t, emptyRoutes, 0)
}
@@ -1646,7 +1645,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
assert.NotNil(t, account.Settings)
@@ -1658,23 +1657,23 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1683,10 +1682,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {
CancelFunc: func(IDs []string) {
wg.Done()
},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
@@ -1694,11 +1693,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first
update := peer.Copy()
update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
_, err = manager.UpdatePeer(account.Id, userID, update)
require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine
update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
_, err = manager.UpdatePeer(account.Id, userID, update)
require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second)
@@ -1711,18 +1710,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1731,18 +1730,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {
CancelFunc: func(IDs []string) {
wg.Done()
},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1755,35 +1754,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {
CancelFunc: func(IDs []string) {
wg.Done()
},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
// enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1796,7 +1795,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@@ -1811,10 +1810,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
updated, err := manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@@ -1822,19 +1821,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
account, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
require.NoError(t, err, "unable to get account by ID")
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})
@@ -2295,7 +2294,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
}
eventStore := &activity.InMemoryEventStore{}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
return nil, err
}
@@ -2306,7 +2305,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
func createStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
store, cleanUp, err := NewTestStoreFromJson(dataDir)
if err != nil {
return nil, err
}

View File

@@ -1,7 +1,6 @@
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@@ -87,7 +86,7 @@ type Store struct {
}
// NewSQLiteStore creates a new Store with an event table if not exists.
func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
dbFile := filepath.Join(dataDir, eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
if err != nil {
@@ -112,7 +111,7 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return nil, err
}
err = updateDeletedUsersTable(ctx, db)
err = updateDeletedUsersTable(db)
if err != nil {
_ = db.Close()
return nil, err
@@ -154,7 +153,7 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return s, nil
}
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
events := make([]*activity.Event, 0)
var cryptErr error
for result.Next() {
@@ -236,14 +235,14 @@ func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*act
}
if cryptErr != nil {
log.WithContext(ctx).Warnf("%s", cryptErr)
log.Warnf("%s", cryptErr)
}
return events, nil
}
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
stmt := store.selectDescStatement
if !descending {
stmt = store.selectAscStatement
@@ -255,11 +254,11 @@ func (store *Store) Get(ctx context.Context, accountID string, offset, limit int
}
defer result.Close() //nolint
return store.processResult(ctx, result)
return store.processResult(result)
}
// Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
var jsonMeta string
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
if err != nil {
@@ -318,15 +317,15 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
}
// Close the Store
func (store *Store) Close(_ context.Context) error {
func (store *Store) Close() error {
if store.db != nil {
return store.db.Close()
}
return nil
}
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
log.WithContext(ctx).Debugf("check deleted_users table version")
func updateDeletedUsersTable(db *sql.DB) error {
log.Debugf("check deleted_users table version")
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
if err != nil {
return err
@@ -361,7 +360,7 @@ func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
return nil
}
log.WithContext(ctx).Debugf("update delted_users table")
log.Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
}

View File

@@ -1,7 +1,6 @@
package sqlite
import (
"context"
"fmt"
"testing"
"time"
@@ -14,17 +13,17 @@ import (
func TestNewSQLiteStore(t *testing.T) {
dataDir := t.TempDir()
key, _ := GenerateKey()
store, err := NewSQLiteStore(context.Background(), dataDir, key)
store, err := NewSQLiteStore(dataDir, key)
if err != nil {
t.Fatal(err)
return
}
defer store.Close(context.Background()) //nolint
defer store.Close() //nolint
accountID := "account_1"
for i := 0; i < 10; i++ {
_, err = store.Save(context.Background(), &activity.Event{
_, err = store.Save(&activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user_" + fmt.Sprint(i),
@@ -37,7 +36,7 @@ func TestNewSQLiteStore(t *testing.T) {
}
}
result, err := store.Get(context.Background(), accountID, 0, 10, false)
result, err := store.Get(accountID, 0, 10, false)
if err != nil {
t.Fatal(err)
return
@@ -46,7 +45,7 @@ func TestNewSQLiteStore(t *testing.T) {
assert.Len(t, result, 10)
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))
result, err = store.Get(context.Background(), accountID, 0, 5, true)
result, err = store.Get(accountID, 0, 5, true)
if err != nil {
t.Fatal(err)
return

View File

@@ -1,18 +1,15 @@
package activity
import (
"context"
"sync"
)
import "sync"
// Store provides an interface to store or stream events.
type Store interface {
// Save an event in the store
Save(ctx context.Context, event *Event) (*Event, error)
Save(event *Event) (*Event, error)
// Get returns "limit" number of events from the "offset" index ordered descending or ascending by a timestamp
Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error)
Get(accountID string, offset, limit int, descending bool) ([]*Event, error)
// Close the sink flushing events if necessary
Close(ctx context.Context) error
Close() error
}
// InMemoryEventStore implements the Store interface storing data in-memory
@@ -23,7 +20,7 @@ type InMemoryEventStore struct {
}
// Save sets the Event.ID to 1
func (store *InMemoryEventStore) Save(_ context.Context, event *Event) (*Event, error) {
func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
store.mu.Lock()
defer store.mu.Unlock()
if store.events == nil {
@@ -36,7 +33,7 @@ func (store *InMemoryEventStore) Save(_ context.Context, event *Event) (*Event,
}
// Get returns a list of ALL events that belong to the given accountID without taking offset, limit and order into consideration
func (store *InMemoryEventStore) Get(_ context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) {
func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descending bool) ([]*Event, error) {
store.mu.Lock()
defer store.mu.Unlock()
events := make([]*Event, 0)
@@ -49,7 +46,7 @@ func (store *InMemoryEventStore) Get(_ context.Context, accountID string, offset
}
// Close cleans up the event list
func (store *InMemoryEventStore) Close(_ context.Context) error {
func (store *InMemoryEventStore) Close() error {
store.mu.Lock()
defer store.mu.Unlock()
store.events = make([]*Event, 0)

View File

@@ -1,8 +0,0 @@
package context
const (
RequestIDKey = "requestID"
AccountIDKey = "accountID"
UserIDKey = "userID"
PeerIDKey = "peerID"
)

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"fmt"
"strconv"
@@ -35,11 +34,11 @@ func (d DNSSettings) Copy() DNSSettings {
}
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
@@ -57,11 +56,11 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
}
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
@@ -90,7 +89,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
account.DNSSettings = dnsSettingsToSave.Copy()
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = am.Store.SaveAccount(account); err != nil {
return err
}
@@ -98,18 +97,17 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
for _, id := range addedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
am.StoreEvent(userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
}
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
for _, id := range removedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
am.StoreEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
// todo: check if before/after groups are in use by dns, acl, routes and if it has peers
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(account)
return nil
}
@@ -151,9 +149,9 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
return protoUpdate
}
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
log.Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
@@ -163,7 +161,7 @@ func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string)
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
@@ -212,14 +210,14 @@ func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
return false
}
func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) {
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
for _, peer := range account.Peers {
label, err := getPeerHostLabel(peer.Name, peerLabels)
if err != nil {
log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
if err != nil {
log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
continue
}
}

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"net/netip"
"testing"
@@ -36,7 +35,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("failed to init testing account")
}
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID)
if err != nil {
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
}
@@ -49,12 +48,12 @@ func TestGetDNSSettings(t *testing.T) {
DisabledManagementGroups: []string{group1ID},
}
err = am.Store.SaveAccount(context.Background(), account)
err = am.Store.SaveAccount(account)
if err != nil {
t.Error("failed to save testing account with new DNS settings")
}
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
dnsSettings, err = am.GetDNSSettings(account.Id, dnsAdminUserID)
if err != nil {
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
}
@@ -63,7 +62,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
}
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
_, err = am.GetDNSSettings(account.Id, dnsRegularUserID)
if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
}
@@ -123,7 +122,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error("failed to init testing account")
}
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
err = am.SaveDNSSettings(account.Id, testCase.userID, testCase.inputSettings)
if err != nil {
if testCase.shouldFail {
return
@@ -131,7 +130,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error(err)
}
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
updatedAccount, err := am.Store.GetAccount(account.Id)
if err != nil {
t.Errorf("should be able to retrieve updated account, got err: %s", err)
}
@@ -165,7 +164,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
t.Error("failed to init testing account")
}
newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
require.NoError(t, err)
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
@@ -174,14 +173,14 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
dnsSettings := account.DNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
account.DNSSettings = dnsSettings
err = am.Store.SaveAccount(context.Background(), account)
err = am.Store.SaveAccount(account)
require.NoError(t, err)
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
require.NoError(t, err)
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
peer2AccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer2.ID)
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
require.NoError(t, err)
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
@@ -195,13 +194,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
}
func createDNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
store, cleanUp, err := NewTestStoreFromJson(dataDir)
if err != nil {
return nil, err
}
@@ -245,28 +244,28 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
domain := "example.com"
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
account := newAccountWithId(dnsAccountID, dnsAdminUserID, domain)
account.Users[dnsRegularUserID] = &User{
Id: dnsRegularUserID,
Role: UserRoleUser,
}
err := am.Store.SaveAccount(context.Background(), account)
err := am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
savedPeer1, _, _, err := am.AddPeer("", dnsAdminUserID, peer1)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
_, _, _, err = am.AddPeer("", dnsAdminUserID, peer2)
if err != nil {
return nil, err
}
account, err = am.Store.GetAccount(context.Background(), account.Id)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return nil, err
}
@@ -313,10 +312,10 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
Groups: []string{allGroup.ID},
}
err = am.Store.SaveAccount(context.Background(), account)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
return am.Store.GetAccount(context.Background(), account.Id)
return am.Store.GetAccount(account.Id)
}

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"sync"
"time"
@@ -52,15 +51,13 @@ func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralM
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head.
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
func (e *EphemeralManager) LoadInitialPeers() {
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.loadEphemeralPeers(ctx)
e.loadEphemeralPeers()
if e.headPeer != nil {
e.timer = time.AfterFunc(ephemeralLifeTime, func() {
e.cleanup(ctx)
})
e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup)
}
}
@@ -76,12 +73,12 @@ func (e *EphemeralManager) Stop() {
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
// is active the manager will not delete it while it is active.
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
log.Tracef("remove peer from ephemeral list: %s", peer.ID)
e.peersLock.Lock()
defer e.peersLock.Unlock()
@@ -97,16 +94,16 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the ephemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
log.Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
a, err := e.store.GetAccountByPeerID(peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
log.Errorf("failed to add peer to ephemeral list: %s", err)
return
}
@@ -119,14 +116,12 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
e.addPeer(peer.ID, a, newDeadLine())
if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
})
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
}
}
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
accounts := e.store.GetAllAccounts(context.Background())
func (e *EphemeralManager) loadEphemeralPeers() {
accounts := e.store.GetAllAccounts()
t := newDeadLine()
count := 0
for _, a := range accounts {
@@ -137,10 +132,10 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
}
}
}
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
log.Debugf("loaded ephemeral peer(s): %d", count)
}
func (e *EphemeralManager) cleanup(ctx context.Context) {
func (e *EphemeralManager) cleanup() {
log.Tracef("on ephemeral cleanup")
deletePeers := make(map[string]*ephemeralPeer)
@@ -159,9 +154,7 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
}
if e.headPeer != nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
})
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
} else {
e.timer = nil
}
@@ -169,10 +162,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
e.peersLock.Unlock()
for id, p := range deletePeers {
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
log.Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
log.Errorf("failed to delete ephemeral peer: %s", err)
}
}
}

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"fmt"
"testing"
"time"
@@ -14,11 +13,11 @@ type MockStore struct {
account *Account
}
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
func (s *MockStore) GetAllAccounts() []*Account {
return []*Account{s.account}
}
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
func (s *MockStore) GetAccountByPeerID(peerId string) (*Account, error) {
_, ok := s.account.Peers[peerId]
if ok {
return s.account, nil
@@ -32,7 +31,7 @@ type MocAccountManager struct {
store *MockStore
}
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error {
delete(a.store.account.Peers, peerID)
return nil //nolint:nil
}
@@ -53,9 +52,9 @@ func TestNewManager(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background())
mgr.loadEphemeralPeers()
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background())
mgr.cleanup()
if len(store.account.Peers) != numberOfPeers {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
@@ -78,11 +77,11 @@ func TestNewManagerPeerConnected(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
mgr.loadEphemeralPeers()
mgr.OnPeerConnected(store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background())
mgr.cleanup()
expected := numberOfPeers + 1
if len(store.account.Peers) != expected {
@@ -106,15 +105,15 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background())
mgr.loadEphemeralPeers()
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
mgr.OnPeerConnected(v)
}
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
mgr.OnPeerDisconnected(store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background())
mgr.cleanup()
expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected {
@@ -123,7 +122,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId(context.Background(), "my account", "", "")
store.account = newAccountWithId("my account", "", "")
for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i)

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"fmt"
"time"
@@ -12,11 +11,11 @@ import (
)
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
@@ -30,7 +29,7 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events")
}
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
events, err := am.eventStore.Get(accountID, 0, 10000, true)
if err != nil {
return nil, err
}
@@ -55,10 +54,10 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI
return filtered, nil
}
func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
go func() {
_, err := am.eventStore.Save(ctx, &activity.Event{
_, err := am.eventStore.Save(&activity.Event{
Timestamp: time.Now().UTC(),
Activity: activityID,
InitiatorID: initiatorID,
@@ -68,7 +67,7 @@ func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, ta
})
if err != nil {
// todo add metric
log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err)
log.Errorf("received an error while storing an activity event, error: %s", err)
}
}()

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"testing"
"time"
@@ -14,7 +13,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
accountID string, count int) {
t.Helper()
for i := 0; i < count; i++ {
_, err := manager.eventStore.Save(context.Background(), &activity.Event{
_, err := manager.eventStore.Save(&activity.Event{
Timestamp: time.Now().UTC(),
Activity: typ,
InitiatorID: initiatorID,
@@ -36,32 +35,32 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
accountID := "accountID"
t.Run("get empty events list", func(t *testing.T) {
events, err := manager.GetEvents(context.Background(), accountID, userID)
events, err := manager.GetEvents(accountID, userID)
if err != nil {
return
}
assert.Len(t, events, 0)
_ = manager.eventStore.Close(context.Background()) //nolint
_ = manager.eventStore.Close() //nolint
})
t.Run("get events", func(t *testing.T) {
generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10)
events, err := manager.GetEvents(context.Background(), accountID, userID)
events, err := manager.GetEvents(accountID, userID)
if err != nil {
return
}
assert.Len(t, events, 10)
_ = manager.eventStore.Close(context.Background()) //nolint
_ = manager.eventStore.Close() //nolint
})
t.Run("get events without duplicates", func(t *testing.T) {
generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10)
events, err := manager.GetEvents(context.Background(), accountID, userID)
events, err := manager.GetEvents(accountID, userID)
if err != nil {
return
}
assert.Len(t, events, 1)
_ = manager.eventStore.Close(context.Background()) //nolint
_ = manager.eventStore.Close() //nolint
})
}

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"os"
"path/filepath"
"strings"
@@ -49,8 +48,8 @@ type FileStore struct {
type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir
func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
fs, err := restore(ctx, filepath.Join(dataDir, storeFileName))
func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
fs, err := restore(filepath.Join(dataDir, storeFileName))
if err != nil {
return nil, err
}
@@ -59,27 +58,27 @@ func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetr
}
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
store, err := NewFileStore(ctx, dataDir, metrics)
func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
store, err := NewFileStore(dataDir, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID())
err = store.SaveInstallationID(sqlStore.GetInstallationID())
if err != nil {
return nil, err
}
for _, account := range sqlStore.GetAllAccounts(ctx) {
for _, account := range sqlStore.GetAllAccounts() {
store.Accounts[account.Id] = account
}
return store, store.persist(ctx, store.storeFile)
return store, store.persist(store.storeFile)
}
// restore the state of the store from the file.
// Creates a new empty store file if doesn't exist
func restore(ctx context.Context, file string) (*FileStore, error) {
func restore(file string) (*FileStore, error) {
if _, err := os.Stat(file); os.IsNotExist(err) {
// create a new FileStore if previously didn't exist (e.g. first run)
s := &FileStore{
@@ -96,7 +95,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
storeFile: file,
}
err = s.persist(ctx, file)
err = s.persist(file)
if err != nil {
return nil, err
}
@@ -166,7 +165,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// for data migration. Can be removed once most base will be with labels
existingLabels := account.getPeerDNSLabels()
if len(existingLabels) != len(account.Peers) {
addPeerLabelsToAccount(ctx, account, existingLabels)
addPeerLabelsToAccount(account, existingLabels)
}
// TODO: delete this block after migration
@@ -179,7 +178,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
allGroup, err := account.GetGroupAll()
if err != nil {
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
// if the All group didn't exist we probably don't have routes to update
continue
}
@@ -237,7 +236,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
}
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
err = store.persist(ctx, store.storeFile)
err = store.persist(store.storeFile)
if err != nil {
return nil, err
}
@@ -247,7 +246,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// persist account data to a file
// It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error {
func (s *FileStore) persist(file string) error {
start := time.Now()
err := util.WriteJson(file, s)
if err != nil {
@@ -257,23 +256,23 @@ func (s *FileStore) persist(ctx context.Context, file string) error {
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.WithContext(ctx).Debugf("took %d ms to persist the FileStore", took.Milliseconds())
log.Debugf("took %d ms to persist the FileStore", took.Milliseconds())
return nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring global lock")
func (s *FileStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
unlock = func() {
s.globalAccountLock.Unlock()
log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start))
log.Debugf("released global lock in %v", time.Since(start))
}
took := time.Since(start)
log.WithContext(ctx).Debugf("took %v to acquire global lock", took)
log.Debugf("took %v to acquire global lock", took)
if s.metrics != nil {
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
}
@@ -282,8 +281,8 @@ func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
}
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
@@ -291,7 +290,7 @@ func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID strin
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
@@ -299,11 +298,11 @@ func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID strin
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(ctx, accountID)
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(accountID)
}
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()
@@ -339,10 +338,10 @@ func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
}
return s.persist(ctx, s.storeFile)
return s.persist(s.storeFile)
}
func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error {
func (s *FileStore) DeleteAccount(account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()
@@ -374,7 +373,7 @@ func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error {
delete(s.Accounts, account.Id)
return s.persist(ctx, s.storeFile)
return s.persist(s.storeFile)
}
// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
@@ -398,7 +397,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}
// GetAccountByPrivateDomain returns account by private domain
func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) {
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -416,7 +415,7 @@ func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string)
}
// GetAccountBySetupKey returns account by setup key id
func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) {
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -434,7 +433,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
}
// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) {
func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -447,7 +446,7 @@ func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (st
}
// GetUserByTokenID returns a User object a tokenID belongs to
func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) {
func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -470,7 +469,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
}
// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) GetAllAccounts() (all []*Account) {
s.mux.Lock()
defer s.mux.Unlock()
for _, a := range s.Accounts {
@@ -491,7 +490,7 @@ func (s *FileStore) getAccount(accountID string) (*Account, error) {
}
// GetAccount returns an account for ID
func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) {
func (s *FileStore) GetAccount(accountID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -504,7 +503,7 @@ func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, e
}
// GetAccountByUser returns a user account
func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) {
func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -522,7 +521,7 @@ func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account
}
// GetAccountByPeerID returns an account for a given peer ID
func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -540,7 +539,7 @@ func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acc
// check Account.Peers for a match
if _, ok := account.Peers[peerID]; !ok {
delete(s.PeerID2AccountID, peerID)
log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
return nil, status.NewPeerNotFoundError(peerID)
}
@@ -548,7 +547,7 @@ func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acc
}
// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -573,14 +572,14 @@ func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string)
}
if stale {
delete(s.PeerKeyID2AccountID, peerKey)
log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
return nil, status.NewPeerNotFoundError(peerKey)
}
return account.Copy(), nil
}
func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) {
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -604,7 +603,7 @@ func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
return accountID, nil
}
func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) {
func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -616,7 +615,7 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
return accountID, nil
}
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -639,7 +638,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
return nil, status.NewPeerNotFoundError(peerKey)
}
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -657,13 +656,13 @@ func (s *FileStore) GetInstallationID() string {
}
// SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
func (s *FileStore) SaveInstallationID(ID string) error {
s.mux.Lock()
defer s.mux.Unlock()
s.InstallationID = ID
return s.persist(ctx, s.storeFile)
return s.persist(s.storeFile)
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
@@ -733,13 +732,13 @@ func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *
}
// Close the FileStore persisting data to disk
func (s *FileStore) Close(ctx context.Context) error {
func (s *FileStore) Close() error {
s.mux.Lock()
defer s.mux.Unlock()
log.WithContext(ctx).Infof("closing FileStore")
log.Infof("closing FileStore")
return s.persist(ctx, s.storeFile)
return s.persist(s.storeFile)
}
// GetStoreEngine returns FileStoreEngine

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"crypto/sha256"
"net"
"path/filepath"
@@ -28,12 +27,12 @@ func TestStalePeerIndices(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
return
}
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
peerID := "some_peer"
@@ -43,24 +42,24 @@ func TestStalePeerIndices(t *testing.T) {
Key: peerKey,
}
err = store.SaveAccount(context.Background(), account)
err = store.SaveAccount(account)
require.NoError(t, err)
account.DeletePeer(peerID)
err = store.SaveAccount(context.Background(), account)
err = store.SaveAccount(account)
require.NoError(t, err)
_, err = store.GetAccountByPeerID(context.Background(), peerID)
_, err = store.GetAccountByPeerID(peerID)
require.Error(t, err, "expecting to get an error when found stale index")
_, err = store.GetAccountByPeerPubKey(context.Background(), peerKey)
_, err = store.GetAccountByPeerPubKey(peerKey)
require.Error(t, err, "expecting to get an error when found stale index")
}
func TestNewStore(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
defer store.Close()
if store.Accounts == nil || len(store.Accounts) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
@@ -89,9 +88,9 @@ func TestNewStore(t *testing.T) {
func TestSaveAccount(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
defer store.Close()
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
account := newAccountWithId("account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
@@ -104,7 +103,7 @@ func TestSaveAccount(t *testing.T) {
}
// SaveAccount should trigger persist
err := store.SaveAccount(context.Background(), account)
err := store.SaveAccount(account)
if err != nil {
return
}
@@ -134,11 +133,11 @@ func TestDeleteAccount(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
t.Fatal(err)
}
defer store.Close(context.Background())
defer store.Close()
var account *Account
for _, a := range store.Accounts {
@@ -148,7 +147,7 @@ func TestDeleteAccount(t *testing.T) {
require.NotNil(t, account, "failed to restore a FileStore file and get at least one account")
err = store.DeleteAccount(context.Background(), account)
err = store.DeleteAccount(account)
require.NoError(t, err, "failed to delete account, error: %v", err)
_, ok := store.Accounts[account.Id]
@@ -184,9 +183,9 @@ func TestDeleteAccount(t *testing.T) {
func TestStore(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
defer store.Close()
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
account := newAccountWithId("account_id", "testuser", "")
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
@@ -229,12 +228,12 @@ func TestStore(t *testing.T) {
})
// SaveAccount should trigger persist
err := store.SaveAccount(context.Background(), account)
err := store.SaveAccount(account)
if err != nil {
return
}
restored, err := NewFileStore(context.Background(), store.storeFile, nil)
restored, err := NewFileStore(store.storeFile, nil)
if err != nil {
return
}
@@ -282,7 +281,7 @@ func TestRestore(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
return
}
@@ -320,7 +319,7 @@ func TestRestoreGroups_Migration(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
return
}
@@ -333,11 +332,11 @@ func TestRestoreGroups_Migration(t *testing.T) {
Name: "All",
},
}
err = store.SaveAccount(context.Background(), account)
err = store.SaveAccount(account)
require.NoError(t, err, "failed to save account")
// restore account with default group with empty Issue field
if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil {
if store, err = NewFileStore(storeDir, nil); err != nil {
return
}
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
@@ -354,18 +353,18 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
return
}
existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
account, err := store.GetAccountByPrivateDomain(existingDomain)
require.NoError(t, err, "should found account")
require.Equal(t, existingDomain, account.Domain, "domains should match")
_, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
require.Error(t, err, "should return error on domain lookup")
}
@@ -383,7 +382,7 @@ func TestFileStore_GetAccount(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
t.Fatal(err)
}
@@ -394,7 +393,7 @@ func TestFileStore_GetAccount(t *testing.T) {
return
}
account, err := store.GetAccount(context.Background(), expected.Id)
account, err := store.GetAccount(expected.Id)
if err != nil {
t.Fatal(err)
}
@@ -425,13 +424,13 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
t.Fatal(err)
}
hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken
tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken)
tokenID, err := store.GetTokenIDByHashedToken(hashedToken)
if err != nil {
t.Fatal(err)
}
@@ -442,7 +441,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
defer store.Close()
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
@@ -479,13 +478,13 @@ func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
t.Fatal(err)
}
wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234"))
_, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:]))
_, err = store.GetTokenIDByHashedToken(string(wrongToken[:]))
assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid")
}
@@ -504,13 +503,13 @@ func TestFileStore_GetUserByTokenID(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
t.Fatal(err)
}
tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
user, err := store.GetUserByTokenID(context.Background(), tokenID)
user, err := store.GetUserByTokenID(tokenID)
if err != nil {
t.Fatal(err)
}
@@ -532,13 +531,13 @@ func TestFileStore_GetUserByTokenID_Failure(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
t.Fatal(err)
}
wrongTokenID := "someNonExistingTokenID"
_, err = store.GetUserByTokenID(context.Background(), wrongTokenID)
_, err = store.GetUserByTokenID(wrongTokenID)
assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid")
}
@@ -551,7 +550,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
return
}
@@ -577,7 +576,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(context.Background(), account)
err = store.SaveAccount(account)
if err != nil {
t.Fatal(err)
}
@@ -603,11 +602,11 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(context.Background(), storeDir, nil)
store, err := NewFileStore(storeDir, nil)
if err != nil {
return
}
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
peer := &nbpeer.Peer{
@@ -626,7 +625,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
assert.Error(t, err)
account.Peers[peer.ID] = peer
err = store.SaveAccount(context.Background(), account)
err = store.SaveAccount(account)
require.NoError(t, err)
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
@@ -637,7 +636,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
assert.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
account, err = store.GetAccount(account.Id)
require.NoError(t, err)
actual := account.Peers[peer.ID].Location
@@ -646,7 +645,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
func newStore(t *testing.T) *FileStore {
t.Helper()
store, err := NewFileStore(context.Background(), t.TempDir(), nil)
store, err := NewFileStore(t.TempDir(), nil)
if err != nil {
t.Errorf("failed creating a new store")
}

View File

@@ -2,7 +2,6 @@ package geolocation
import (
"bytes"
"context"
"fmt"
"net"
"os"
@@ -53,7 +52,7 @@ type Country struct {
CountryName string
}
func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
func NewGeolocation(dataDir string) (*Geolocation, error) {
if err := loadGeolocationDatabases(dataDir); err != nil {
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
}
@@ -69,7 +68,7 @@ func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
return nil, err
}
locationDB, err := NewSqliteStore(ctx, dataDir)
locationDB, err := NewSqliteStore(dataDir)
if err != nil {
return nil, err
}
@@ -84,7 +83,7 @@ func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
stopCh: make(chan struct{}),
}
go geo.reloader(ctx)
go geo.reloader()
return geo, nil
}
@@ -166,19 +165,19 @@ func (gl *Geolocation) Stop() error {
return nil
}
func (gl *Geolocation) reloader(ctx context.Context) {
func (gl *Geolocation) reloader() {
for {
select {
case <-gl.stopCh:
return
case <-time.After(gl.reloadCheckInterval):
if err := gl.locationDB.reload(ctx); err != nil {
log.WithContext(ctx).Errorf("geonames db reload failed: %s", err)
if err := gl.locationDB.reload(); err != nil {
log.Errorf("geonames db reload failed: %s", err)
}
newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
if err != nil {
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue
}
if !bytes.Equal(gl.sha256sum, newSha256sum1) {
@@ -187,30 +186,30 @@ func (gl *Geolocation) reloader(ctx context.Context) {
time.Sleep(50 * time.Millisecond)
newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
if err != nil {
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue
}
if !bytes.Equal(newSha256sum1, newSha256sum2) {
log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
log.Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
continue
}
err = gl.reload(ctx, newSha256sum2)
err = gl.reload(newSha256sum2)
if err != nil {
log.WithContext(ctx).Errorf("mmdb reload failed: %s", err)
log.Errorf("mmdb reload failed: %s", err)
}
} else {
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
log.Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
gl.mmdbPath, gl.reloadCheckInterval.Seconds())
}
}
}
}
func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error {
func (gl *Geolocation) reload(newSha256sum []byte) error {
gl.mux.Lock()
defer gl.mux.Unlock()
log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath)
log.Infof("Reloading '%s'", gl.mmdbPath)
err := gl.db.Close()
if err != nil {
@@ -225,7 +224,7 @@ func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error {
gl.db = db
gl.sha256sum = newSha256sum
log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath)
log.Infof("Successfully reloaded '%s'", gl.mmdbPath)
return nil
}

View File

@@ -2,7 +2,6 @@ package geolocation
import (
"bytes"
"context"
"fmt"
"path/filepath"
"runtime"
@@ -51,10 +50,10 @@ type SqliteStore struct {
sha256sum []byte
}
func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) {
func NewSqliteStore(dataDir string) (*SqliteStore, error) {
file := filepath.Join(dataDir, GeoSqliteDBFile)
db, err := connectDB(ctx, file)
db, err := connectDB(file)
if err != nil {
return nil, err
}
@@ -116,13 +115,13 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
}
// reload attempts to reload the SqliteStore's database if the database file has changed.
func (s *SqliteStore) reload(ctx context.Context) error {
func (s *SqliteStore) reload() error {
s.mux.Lock()
defer s.mux.Unlock()
newSha256sum1, err := calculateFileSHA256(s.filePath)
if err != nil {
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
}
if !bytes.Equal(s.sha256sum, newSha256sum1) {
@@ -137,11 +136,11 @@ func (s *SqliteStore) reload(ctx context.Context) error {
return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath)
}
log.WithContext(ctx).Infof("Reloading '%s'", s.filePath)
log.Infof("Reloading '%s'", s.filePath)
_ = s.close()
s.closed = true
newDb, err := connectDB(ctx, s.filePath)
newDb, err := connectDB(s.filePath)
if err != nil {
return err
}
@@ -149,9 +148,9 @@ func (s *SqliteStore) reload(ctx context.Context) error {
s.closed = false
s.db = newDb
log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath)
log.Infof("Successfully reloaded '%s'", s.filePath)
} else {
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath)
log.Tracef("No changes in '%s', no need to reload", s.filePath)
}
return nil
@@ -169,10 +168,10 @@ func (s *SqliteStore) close() error {
}
// connectDB connects to an SQLite database and prepares it by setting up an in-memory database.
func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) {
func connectDB(filePath string) (*gorm.DB, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("took %v to setup geoname db", time.Since(start))
log.Debugf("took %v to setup geoname db", time.Since(start))
}()
_, err := fileExists(filePath)

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"fmt"
"github.com/rs/xid"
@@ -22,11 +21,11 @@ func (e *GroupLinkError) Error() string {
}
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
@@ -49,11 +48,11 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
}
// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
@@ -76,11 +75,11 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID str
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
@@ -109,11 +108,11 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
}
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
@@ -151,12 +150,11 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
account.Groups[newGroup.ID] = newGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = am.Store.SaveAccount(account); err != nil {
return err
}
// todo: check if groups is in use by dns, acl, routes and before/after peers
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(account)
// the following snippet tracks the activity and stores the group events in the event store.
// It has to happen after all the operations have been successfully performed.
@@ -167,16 +165,16 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
}
for _, p := range addedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
@@ -186,10 +184,10 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
for _, p := range removedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
@@ -215,11 +213,11 @@ func difference(a, b []string) []string {
}
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountId)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
account, err := am.Store.GetAccount(accountId)
if err != nil {
return err
}
@@ -317,24 +315,23 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = am.Store.SaveAccount(account); err != nil {
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
// todo: check if groups is in use by dns, acl, routes and if it has peers
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(account)
return nil
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
@@ -348,11 +345,11 @@ func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID strin
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
@@ -374,22 +371,21 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = am.Store.SaveAccount(account); err != nil {
return err
}
// todo: check if groups is in use by dns, acl, routes
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(account)
return nil
}
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
@@ -403,14 +399,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
if err := am.Store.SaveAccount(account); err != nil {
return err
}
}
}
// todo: check if groups is in use by dns, acl, routes
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(account)
return nil
}

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"errors"
"testing"
@@ -27,7 +26,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
}
for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedIntegration
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
err = am.SaveGroup(account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration)
}
@@ -35,7 +34,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedJWT
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
err = am.SaveGroup(account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT)
}
@@ -43,7 +42,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedAPI
group.ID = ""
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
err = am.SaveGroup(account.Id, groupAdminUserID, group)
if err == nil {
t.Errorf("should not create api group with the same name, %s", group.Name)
}
@@ -105,7 +104,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, testCase.groupID)
err = am.DeleteGroup(account.Id, groupAdminUserID, testCase.groupID)
if err == nil {
t.Errorf("delete %s group successfully", testCase.groupID)
return
@@ -226,7 +225,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
Id: "example user",
AutoGroups: []string{groupForUsers.ID},
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account := newAccountWithId(accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
@@ -234,18 +233,18 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account)
err := am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForIntegration)
return am.Store.GetAccount(context.Background(), account.Id)
return am.Store.GetAccount(account.Id)
}

View File

@@ -11,14 +11,12 @@ import (
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/netbird/management/server/posture"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -42,7 +40,7 @@ type GRPCServer struct {
}
// NewServer creates a new Management server
func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@@ -52,7 +50,6 @@ func NewServer(ctx context.Context, config *Config, accountManager AccountManage
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator(
ctx,
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
@@ -62,7 +59,7 @@ func NewServer(ctx context.Context, config *Config, accountManager AccountManage
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
}
} else {
log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware")
log.Debug("unable to use http config to create new jwt middleware")
}
if appMetrics != nil {
@@ -129,61 +126,47 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
ctx := srv.Context()
realIP := getRealIP(ctx)
realIP := getRealIP(srv.Context())
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(ctx, req, syncReq)
peerKey, err := s.parseRequest(req, syncReq)
if err != nil {
return err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
accountID = "UNKNOWN"
}
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
if syncReq.GetMeta() == nil {
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP)
if err != nil {
return mapError(ctx, err)
return mapError(err)
}
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
err = s.sendInitialSync(peerKey, peer, netMap, postureChecks, srv)
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
return err
}
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
updates := s.peersUpdateManager.CreateChannel(peer.ID)
s.ephemeralManager.OnPeerConnected(ctx, peer)
s.ephemeralManager.OnPeerConnected(peer)
if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(ctx, peer.ID)
s.turnCredentialsManager.SetupRefresh(peer.ID)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
return s.handleUpdates(peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for {
select {
// condition when there are some updates
@@ -193,21 +176,21 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
}
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
log.Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(peer)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
log.Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
if err := s.sendUpdate(peerKey, peer, update, srv); err != nil {
return err
}
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
log.Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(peer)
return srv.Context().Err()
}
}
@@ -215,10 +198,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
@@ -226,37 +209,37 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
log.Debugf("sent an update to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
_ = s.accountManager.CancelPeerRoutines(peer)
s.ephemeralManager.OnPeerDisconnected(peer)
}
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt validator set")
}
token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken)
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
_, _, err = s.accountManager.GetAccountFromToken(claims)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil {
return "", status.Errorf(codes.PermissionDenied, err.Error())
}
@@ -264,7 +247,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
}
// maps internal internalStatus.Error to gRPC status.Error
func mapError(ctx context.Context, err error) error {
func mapError(err error) error {
if e, ok := internalStatus.FromError(err); ok {
switch e.Type() {
case internalStatus.PermissionDenied:
@@ -280,11 +263,11 @@ func mapError(ctx context.Context, err error) error {
default:
}
}
log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
log.Errorf("got an unhandled error: %s", err)
return status.Errorf(codes.Internal, "failed handling request")
}
func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
if meta == nil {
return nbpeer.PeerSystemMeta{}
}
@@ -298,7 +281,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
for _, addr := range meta.GetNetworkAddresses() {
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
if err != nil {
log.WithContext(ctx).Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
continue
}
networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{
@@ -338,10 +321,10 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
}
}
func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
}
@@ -368,32 +351,22 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq)
peerKey, err := s.parseRequest(req, loginReq)
if err != nil {
return nil, err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
accountID = "UNKNOWN"
}
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
if loginReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
log.WithContext(ctx).Warn(msg)
log.Warn(msg)
return nil, msg
}
userID, err := s.processJwtToken(ctx, loginReq, peerKey)
userID, err := s.processJwtToken(loginReq, peerKey)
if err != nil {
return nil, err
}
@@ -403,33 +376,33 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, PeerLogin{
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
Meta: extractPeerMeta(loginReq.GetMeta()),
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP,
})
if err != nil {
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(ctx, err)
log.Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(err)
}
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
s.ephemeralManager.OnPeerDisconnected(peer)
}
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toProtocolChecks(ctx, postureChecks),
Checks: toProtocolChecks(postureChecks),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
log.Warnf("failed encrypting peer %s message", peer.ID)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
@@ -444,16 +417,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
//
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
// registered or if it uses a setup key to register.
func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
func (s *GRPCServer) processJwtToken(loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
userID := ""
if loginReq.GetJwtToken() != "" {
var err error
for i := 0; i < 3; i++ {
userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
userID, err = s.validateToken(loginReq.GetJwtToken())
if err == nil {
break
}
log.WithContext(ctx).Warnf("failed validating JWT token sent from peer %s with error %v. "+
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+
"Trying again as it may be due to the IdP cache issue", peerKey.String(), err)
time.Sleep(200 * time.Millisecond)
}
@@ -547,7 +520,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@@ -578,7 +551,7 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
Checks: toProtocolChecks(ctx, checks),
Checks: toProtocolChecks(checks),
}
}
@@ -588,7 +561,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional
var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials {
@@ -597,7 +570,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
@@ -610,7 +583,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
})
if err != nil {
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
log.Errorf("failed sending SyncResponse %v", err)
return status.Errorf(codes.Internal, "error handling request")
}
@@ -624,14 +597,14 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
@@ -672,18 +645,18 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
// This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
@@ -719,10 +692,10 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
// peer's under the same account of any updates.
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
syncMetaReq := &proto.SyncMetaRequest{}
peerKey, err := s.parseRequest(ctx, req, syncMetaReq)
peerKey, err := s.parseRequest(req, syncMetaReq)
if err != nil {
return nil, err
}
@@ -730,20 +703,20 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
if syncMetaReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.WithContext(ctx).Warn(msg)
log.Warn(msg)
return nil, msg
}
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
err = s.accountManager.SyncPeerMeta(peerKey.String(), extractPeerMeta(syncMetaReq.GetMeta()))
if err != nil {
return nil, mapError(ctx, err)
return nil, mapError(err)
}
return &proto.Empty{}, nil
}
// toProtocolChecks converts posture checks to protocol checks.
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
func toProtocolChecks(postureChecks []*posture.Checks) []*proto.Checks {
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
for _, postureCheck := range postureChecks {
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))

121
management/server/hash.go Normal file
View File

@@ -0,0 +1,121 @@
package server
import (
"github.com/mitchellh/hashstructure/v2"
"github.com/r3labs/diff"
log "github.com/sirupsen/logrus"
)
func updateAccountPeers(account *Account) {
//start := time.Now()
//defer func() {
// duration := time.Since(start)
// log.Printf("Finished execution of updateAccountPeers, took %v\n", duration)
//}()
peers := account.GetPeers()
approvedPeersMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
approvedPeersMap[peer.ID] = struct{}{}
}
for _, peer := range peers {
//if !am.peersUpdateManager.HasChannel(peer.ID) {
// log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
// continue
//}
_ = account.GetPeerNetworkMap(peer.ID, "netbird.io", approvedPeersMap)
//remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
//postureChecks := am.getPeerPostureChecks(account, peer)
//update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
//am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
}
}
func updateAccountPeersWithHash(account *Account) {
//start := time.Now()
//var skipUpdate int
//defer func() {
// duration := time.Since(start)
// log.Printf("Finished execution of updateAccountPeers, took %v\n", duration.Nanoseconds())
// log.Println("not updated peers: ", skipUpdate)
//}()
peers := account.GetPeers()
approvedPeersMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
approvedPeersMap[peer.ID] = struct{}{}
}
for _, peer := range peers {
//if !am.peersUpdateManager.HasChannel(peer.ID) {
// log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
// continue
//}
//33006042459
// 8700718125
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, "netbird.io", approvedPeersMap)
//log.Println("firewall rules: ", len(remotePeerNetworkMap.FirewallRules))
hashStr, err := hashstructure.Hash(remotePeerNetworkMap, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
//Hasher: xxhash.New(),
})
if err != nil {
log.Errorf("failed to generate network map hash: %v", err)
} else {
if peer.NetworkMapHash == hashStr {
//log.Debugf("not sending network map update to peer: %s as there is nothing new", peer.ID)
//skipUpdate++
continue
}
peer.NetworkMapHash = hashStr
}
}
}
func updateAccountPeersWithDiff(account *Account) {
//start := time.Now()
//var skipUpdate int
//defer func() {
// duration := time.Since(start)
// log.Printf("Finished execution of updateAccountPeers, took %v\n", duration.Nanoseconds())
// log.Println("not updated peers: ", skipUpdate)
//}()
peers := account.GetPeers()
approvedPeersMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
approvedPeersMap[peer.ID] = struct{}{}
}
for _, peer := range peers {
//if !am.peersUpdateManager.HasChannel(peer.ID) {
// log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
// continue
//}
//33006042459
// 8700718125
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, "netbird.io", approvedPeersMap)
peer.NetworkMap = remotePeerNetworkMap
changelog, err := diff.Diff(peer.NetworkMap, remotePeerNetworkMap)
if err != nil {
log.Errorf("failed to generate network map diff: %v", err)
} else {
if len(changelog) == 0 {
continue
}
}
}
}
//48868101197
// 8700718125

View File

@@ -0,0 +1,424 @@
package server
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/mitchellh/hashstructure/v2"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
route2 "github.com/netbirdio/netbird/route"
"github.com/r3labs/diff"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func initTestAccount(b *testing.B, numPerAccount int) *Account {
b.Helper()
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
b.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
Description: "base route",
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
group := &nbgroup.Group{
ID: "randomID",
AccountID: account.Id,
Name: "randomName",
Issued: "api",
Peers: groupALL.Peers[:numPerAccount-1],
}
account.Groups[group.ID] = group
account.Policies = []*Policy{
{
ID: "RuleDefault",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Enabled: true,
Rules: []*PolicyRule{
{
ID: "RuleDefault",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
Sources: []string{
group.ID,
},
Destinations: []string{
group.ID,
},
},
{
ID: "RuleDefault2",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolUDP,
Action: PolicyTrafficActionAccept,
Sources: []string{
groupALL.ID,
},
Destinations: []string{
groupALL.ID,
},
},
},
},
}
return account
}
// 1000 - 6717416375 ns/op
// 500 - 1732888875 ns/op
func BenchmarkTest_updateAccountPeers100(b *testing.B) {
account := initTestAccount(b, 100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash100(b *testing.B) {
account := initTestAccount(b, 100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff100(b *testing.B) {
account := initTestAccount(b, 100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
// 1000 - 6717416375 ns/op
// 500 - 1732888875 ns/op
func BenchmarkTest_updateAccountPeers200(b *testing.B) {
account := initTestAccount(b, 200)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash200(b *testing.B) {
account := initTestAccount(b, 200)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff200(b *testing.B) {
account := initTestAccount(b, 200)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
func BenchmarkTest_updateAccountPeers500(b *testing.B) {
account := initTestAccount(b, 500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash500(b *testing.B) {
account := initTestAccount(b, 500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff500(b *testing.B) {
account := initTestAccount(b, 500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
func BenchmarkTest_updateAccountPeers1000(b *testing.B) {
account := initTestAccount(b, 1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash1000(b *testing.B) {
account := initTestAccount(b, 1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff1000(b *testing.B) {
account := initTestAccount(b, 1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
func BenchmarkTest_updateAccountPeers1500(b *testing.B) {
account := initTestAccount(b, 1500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash1500(b *testing.B) {
account := initTestAccount(b, 1500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff1500(b *testing.B) {
account := initTestAccount(b, 1500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
func BenchmarkTest_updateAccountPeers2000(b *testing.B) {
account := initTestAccount(b, 2000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash2000(b *testing.B) {
account := initTestAccount(b, 2000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff2000(b *testing.B) {
account := initTestAccount(b, 2000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
type TestStruct struct {
Name string
Value int
Ignored string `diff:"-" hash:"ignore"`
Compared string
}
func TestDiffIgnoreTag(t *testing.T) {
a := TestStruct{
Name: "test",
Value: 30,
Ignored: "This should be ignored",
Compared: "This should be compared",
}
b := TestStruct{
Name: "test",
Value: 31,
Ignored: "This is different but should be ignored",
Compared: "This is different and should be compared",
}
changelog, err := diff.Diff(a, b)
assert.NoError(t, err)
// Check that only the expected fields are in the changelog
assert.Len(t, changelog, 2)
// Check that the 'Age' field change is detected
ageChange := getChangeForField(changelog, "Value")
assert.NotNil(t, ageChange)
assert.Equal(t, 30, ageChange.From)
assert.Equal(t, 31, ageChange.To)
// Check that the 'Compared' field change is detected
comparedChange := getChangeForField(changelog, "Compared")
assert.NotNil(t, comparedChange)
assert.Equal(t, "This should be compared", comparedChange.From)
assert.Equal(t, "This is different and should be compared", comparedChange.To)
// Check that the 'Ignored' field is not in the changelog
ignoredChange := getChangeForField(changelog, "Ignored")
assert.Nil(t, ignoredChange)
}
func TestHashIgnoreTag(t *testing.T) {
a := TestStruct{
Name: "test",
Value: 30,
Ignored: "This should be ignored",
Compared: "This should be compared",
}
b := TestStruct{
Name: "test",
Value: 30,
Ignored: "This is different but should be ignored",
Compared: "This should be compared",
}
c := TestStruct{
Name: "test",
Value: 31,
Ignored: "This should be ignored",
Compared: "This should be compared",
}
d := TestStruct{
Name: "test",
Value: 30,
Ignored: "This should be ignored",
Compared: "This is different and should be compared",
}
opts := &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
}
hashA, err := hashstructure.Hash(a, hashstructure.FormatV2, opts)
assert.NoError(t, err)
hashB, err := hashstructure.Hash(b, hashstructure.FormatV2, opts)
assert.NoError(t, err)
hashC, err := hashstructure.Hash(c, hashstructure.FormatV2, opts)
assert.NoError(t, err)
hashD, err := hashstructure.Hash(d, hashstructure.FormatV2, opts)
assert.NoError(t, err)
// Test that changing the ignored field does not change the hash
assert.Equal(t, hashA, hashB)
// Test that changing a non-ignored field does change the hash
assert.NotEqual(t, hashA, hashC)
assert.NotEqual(t, hashA, hashD)
}
func getChangeForField(changelog diff.Changelog, fieldName string) *diff.Change {
for _, change := range changelog {
if change.Path[0] == fieldName {
return &change
}
}
return nil
}

View File

@@ -35,34 +35,34 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
if !(user.HasAdminPower() || user.IsServiceUser) {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
return
}
resp := toAccountResponse(account)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
util.WriteJSONObject(w, []*api.Account{resp})
}
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
_, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
accountID := vars["accountId"]
if len(accountID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
return
}
@@ -96,15 +96,15 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
resp := toAccountResponse(updatedAccount)
util.WriteJSONObject(r.Context(), w, &resp)
util.WriteJSONObject(w, &resp)
}
// DeleteAccount is a HTTP DELETE handler to delete an account
@@ -118,17 +118,17 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid account ID"), w)
return
}
err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId)
err := h.accountManager.DeleteAccount(targetAccountID, claims.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
func toAccountResponse(account *server.Account) *api.Account {

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -23,10 +22,10 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return account, admin, nil
},
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
UpdateAccountSettingsFunc: func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")

View File

@@ -32,16 +32,16 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
dnsSettings, err := h.accountManager.GetDNSSettings(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -49,15 +49,15 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
DisabledManagementGroups: dnsSettings.DisabledManagementGroups,
}
util.WriteJSONObject(r.Context(), w, apiDNSSettings)
util.WriteJSONObject(w, apiDNSSettings)
}
// UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -72,9 +72,9 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
err = h.accountManager.SaveDNSSettings(account.Id, user.Id, updateDNSSettings)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -82,5 +82,5 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups,
}
util.WriteJSONObject(r.Context(), w, &resp)
util.WriteJSONObject(w, &resp)
}

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -43,16 +42,16 @@ var testingDNSSettingsAccount = &server.Account{
func initDNSSettingsTestData() *DNSSettingsHandler {
return &DNSSettingsHandler{
accountManager: &mock_server.MockAccountManager{
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) {
return &testingDNSSettingsAccount.DNSSettings, nil
},
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
if dnsSettingsToSave != nil {
return nil
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
},
},

View File

@@ -1,7 +1,6 @@
package http
import (
"context"
"fmt"
"net/http"
@@ -34,16 +33,16 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
// GetAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
accountEvents, err := h.accountManager.GetEvents(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
events := make([]*api.Event, len(accountEvents))
@@ -51,20 +50,20 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e)
}
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
err = h.fillEventsWithUserInfo(events, account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, events)
util.WriteJSONObject(w, events)
}
func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error {
// build email, name maps based on users
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
if err != nil {
log.WithContext(ctx).Errorf("failed to get users from account: %s", err)
log.Errorf("failed to get users from account: %s", err)
return err
}
@@ -81,7 +80,7 @@ func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*ap
if event.InitiatorEmail == "" {
event.InitiatorEmail, ok = emails[event.InitiatorId]
if !ok {
log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
}
}

View File

@@ -1,7 +1,6 @@
package http
import (
"context"
"encoding/json"
"io"
"net/http"
@@ -23,13 +22,13 @@ import (
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
return &EventsHandler{
accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
GetEventsFunc: func(accountID, userID string) ([]*activity.Event, error) {
if accountID == account {
return events, nil
}
return []*activity.Event{}, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
@@ -38,7 +37,7 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
},
}, user, nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
},
},

View File

@@ -1,7 +1,6 @@
package http
import (
"context"
"encoding/json"
"io"
"net/http"
@@ -36,13 +35,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile))
assert.NoError(t, err)
geo, err := geolocation.NewGeolocation(context.Background(), tempDir)
geo, err := geolocation.NewGeolocation(tempDir)
assert.NoError(t, err)
t.Cleanup(func() { _ = geo.Stop() })
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,

View File

@@ -40,19 +40,19 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca
// GetAllCountries retrieves a list of all countries
func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
allCountries, err := l.geolocationManager.GetAllCountries()
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -60,32 +60,32 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req
for _, country := range allCountries {
countries = append(countries, toCountryResponse(country))
}
util.WriteJSONObject(r.Context(), w, countries)
util.WriteJSONObject(w, countries)
}
// GetCitiesByCountry retrieves a list of cities based on the given country code
func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid country code"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid country code"), w)
return
}
if l.geolocationManager == nil {
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}
allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -93,12 +93,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
for _, city := range allCities {
cities = append(cities, toCityResponse(city))
}
util.WriteJSONObject(r.Context(), w, cities)
util.WriteJSONObject(w, cities)
}
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r)
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
_, user, err := l.accountManager.GetAccountFromToken(claims)
if err != nil {
return err
}

View File

@@ -35,16 +35,16 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
// GetAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
groups, err := h.accountManager.GetAllGroups(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -53,42 +53,42 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
}
util.WriteJSONObject(r.Context(), w, groupsResponse)
util.WriteJSONObject(w, groupsResponse)
}
// UpdateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
groupID, ok := vars["groupId"]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
return
}
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
return
}
eg, ok := account.Groups[groupID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return
}
@@ -100,7 +100,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
@@ -118,21 +118,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: eg.IntegrationReference,
}
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
util.WriteError(r.Context(), err, w)
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
util.WriteJSONObject(w, toGroupResponse(account, &group))
}
// CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -144,7 +144,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
@@ -160,62 +160,62 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
Issued: nbgroup.GroupIssuedAPI,
}
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
util.WriteJSONObject(w, toGroupResponse(account, &group))
}
// DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
aID := account.Id
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return
}
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
err = h.accountManager.DeleteGroup(aID, user.Id, groupID)
if err != nil {
_, ok := err.(*server.GroupLinkError)
if ok {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
// GetGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -223,19 +223,19 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
case http.MethodGet:
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
util.WriteJSONObject(w, toGroupResponse(account, group))
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
return
}
}

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -33,13 +32,13 @@ var TestPeers = map[string]*nbpeer.Peer{
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error {
if !strings.HasPrefix(group.ID, "id-") {
group.ID = "id-was-set"
}
return nil
},
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) {
if groupID != "idofthegroup" {
return nil, status.Errorf(status.NotFound, "not found")
}
@@ -56,7 +55,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
Issued: nbgroup.GroupIssuedAPI,
}, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
@@ -71,7 +70,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
},
}, user, nil
},
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
DeleteGroupFunc: func(accountID, userId, groupID string) error {
if groupID == "linked-grp" {
return &server.GroupLinkError{
Resource: "something",

View File

@@ -9,7 +9,6 @@ import (
"github.com/rs/cors"
"github.com/netbirdio/management-integrations/integrations"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
@@ -58,11 +57,6 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
corsMiddleware := cors.AllowAll()
claimsExtractor = jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
acMiddleware := middleware.NewAccessControl(
authCfg.Audience,
authCfg.UserIDClaim,

View File

@@ -1,7 +1,6 @@
package middleware
import (
"context"
"net/http"
"regexp"
@@ -16,7 +15,7 @@ import (
)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
@@ -47,15 +46,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
claims := a.claimsExtract.FromRequestContext(r)
user, err := a.getUser(r.Context(), claims)
user, err := a.getUser(claims)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
log.Errorf("failed to get user from claims: %s", err)
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
}
if user.IsBlocked() {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
return
}
@@ -64,12 +63,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
if tokenPathRegexp.MatchString(r.URL.Path) {
log.WithContext(r.Context()).Debugf("valid Path")
log.Debugf("valid Path")
h.ServeHTTP(w, r)
return
}
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
return
}
}

View File

@@ -12,7 +12,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -20,16 +19,16 @@ import (
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(ctx context.Context, token string) error
type MarkPATUsedFunc func(token string) error
// CheckUserAccessByJWTGroupsFunc function
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
@@ -86,27 +85,23 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
case "bearer":
err := m.checkJWTFromRequest(w, r, auth)
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
log.Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
err := m.checkPATFromRequest(w, r, auth)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
log.Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
default:
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
claims := m.claimsExtractor.FromRequestContext(r)
//nolint
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
@@ -119,7 +114,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(r.Context(), token)
validatedToken, err := m.validateAndParseToken(token)
if err != nil {
return err
}
@@ -128,7 +123,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
if err := m.verifyUserAccess(validatedToken); err != nil {
return err
}
@@ -143,9 +138,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
return m.checkUserAccessByJWTGroups(ctx, authClaims)
return m.checkUserAccessByJWTGroups(authClaims)
}
// CheckPATFromRequest checks if the PAT is valid
@@ -157,7 +152,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
account, user, pat, err := m.getAccountFromPAT(token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
@@ -165,7 +160,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("token expired")
}
err = m.markPATUsed(r.Context(), pat.ID)
err = m.markPATUsed(pat.ID)
if err != nil {
return err
}

View File

@@ -1,7 +1,6 @@
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@@ -16,16 +15,15 @@ import (
)
const (
audience = "audience"
userIDClaim = "userIDClaim"
accountID = "accountID"
domain = "domain"
domainCategory = "domainCategory"
userID = "userID"
tokenID = "tokenID"
PAT = "nbp_PAT"
JWT = "JWT"
wrongToken = "wrongToken"
audience = "audience"
userIDClaim = "userIDClaim"
accountID = "accountID"
domain = "domain"
userID = "userID"
tokenID = "tokenID"
PAT = "nbp_PAT"
JWT = "JWT"
wrongToken = "wrongToken"
)
var testAccount = &server.Account{
@@ -49,14 +47,14 @@ var testAccount = &server.Account{
},
}
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{
Claims: jwt.MapClaims{
@@ -69,14 +67,14 @@ func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, err
return nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(_ context.Context, token string) error {
func mockMarkPATUsed(token string) error {
if token == tokenID {
return nil
}
return fmt.Errorf("Should never get reached")
}
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
if testAccount.Id != claims.AccountId {
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
}

View File

@@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *
for bypassPath := range bypassPaths {
matched, err := path.Match(bypassPath, requestPath)
if err != nil {
log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
continue
}
if matched {

View File

@@ -36,16 +36,16 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -54,15 +54,15 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
}
util.WriteJSONObject(r.Context(), w, apiNameservers)
util.WriteJSONObject(w, apiNameservers)
}
// CreateNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -75,33 +75,33 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(r.Context(), w, &resp)
util.WriteJSONObject(w, &resp)
}
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
@@ -114,7 +114,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return
}
@@ -130,66 +130,66 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled,
}
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
err = h.accountManager.SaveNameServerGroup(account.Id, user.Id, updatedNSGroup)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
util.WriteJSONObject(r.Context(), w, &resp)
util.WriteJSONObject(w, &resp)
}
// DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
// GetNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, user.Id, nsGroupID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(r.Context(), w, &resp)
util.WriteJSONObject(w, &resp)
}
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -62,13 +61,13 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
func initNameserversTestData() *NameserversHandler {
return &NameserversHandler{
accountManager: &mock_server.MockAccountManager{
GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
GetNameServerGroupFunc: func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if nsGroupID == existingNSGroupID {
return baseExistingNSGroup.Copy(), nil
}
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
},
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: name,
@@ -81,16 +80,16 @@ func initNameserversTestData() *NameserversHandler {
SearchDomainsEnabled: searchDomains,
}, nil
},
DeleteNameServerGroupFunc: func(_ context.Context, accountID, nsGroupID, _ string) error {
DeleteNameServerGroupFunc: func(accountID, nsGroupID, _ string) error {
return nil
},
SaveNameServerGroupFunc: func(_ context.Context, accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
SaveNameServerGroupFunc: func(accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
if nsGroupToSave.ID == existingNSGroupID {
return nil
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil
},
},

View File

@@ -34,22 +34,22 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
userID := vars["userId"]
if len(userID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -58,53 +58,53 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
patResponse = append(patResponse, toPATResponse(pat))
}
util.WriteJSONObject(r.Context(), w, patResponse)
util.WriteJSONObject(w, patResponse)
}
// GetToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
tokenID := vars["tokenId"]
if len(tokenID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return
}
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
util.WriteJSONObject(w, toPATResponse(pat))
}
// CreateToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
@@ -115,44 +115,44 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
return
}
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
util.WriteJSONObject(w, toPATGeneratedResponse(pat))
}
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
tokenID := vars["tokenId"]
if len(tokenID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return
}
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -64,7 +63,7 @@ var testAccount = &server.Account{
func initPATTestData() *PATHandler {
return &PATHandler{
accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@@ -77,10 +76,10 @@ func initPATTestData() *PATHandler {
}, nil
},
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testAccount, testAccount.Users[existingUserID], nil
},
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@@ -92,7 +91,7 @@ func initPATTestData() *PATHandler {
}
return nil
},
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@@ -104,7 +103,7 @@ func initPATTestData() *PATHandler {
}
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
},
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}

View File

@@ -1,7 +1,6 @@
package http
import (
"context"
"encoding/json"
"fmt"
"net/http"
@@ -48,16 +47,16 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
return peerToReturn, nil
}
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(account.Id, peerID, userID)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(err, w)
return
}
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
@@ -66,19 +65,19 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
log.Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w)
return
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
}
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -100,9 +99,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
}
}
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
@@ -111,75 +110,75 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
log.Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w)
return
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
}
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(accountID, peerID, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
util.WriteError(ctx, err, w)
log.Errorf("failed to delete peer: %v", err)
util.WriteError(err, w)
return
}
util.WriteJSONObject(ctx, w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
switch r.Method {
case http.MethodDelete:
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
h.deletePeer(account.Id, user.Id, peerID, w)
return
case http.MethodPut:
h.updatePeer(r.Context(), account, user, peerID, w, r)
h.updatePeer(account, user, peerID, w, r)
return
case http.MethodGet:
h.getPeer(r.Context(), account, peerID, user.Id, w)
h.getPeer(account, peerID, user.Id, w)
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
}
}
// GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -189,34 +188,34 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
}
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
log.Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w)
return
}
h.setApprovalRequiredFlag(respBody, validPeersMap)
util.WriteJSONObject(r.Context(), w, respBody)
util.WriteJSONObject(w, respBody)
}
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) {
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
return 0, err
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
}

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net"
@@ -30,7 +29,7 @@ const noUpdateChannelTestPeerID = "no-update-channel"
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
UpdatePeerFunc: func(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
var p *nbpeer.Peer
for _, peer := range peers {
if update.ID == peer.ID {
@@ -43,7 +42,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
p.Name = update.Name
return p, nil
},
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
GetPeerFunc: func(accountID, peerID, userID string) (*nbpeer.Peer, error) {
var p *nbpeer.Peer
for _, peer := range peers {
if peerID == peer.ID {
@@ -53,13 +52,13 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
}
return p, nil
},
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil
},
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,

View File

@@ -35,15 +35,15 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -51,28 +51,28 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
for _, policy := range accountPolicies {
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
policies = append(policies, resp)
}
util.WriteJSONObject(r.Context(), w, policies)
util.WriteJSONObject(w, policies)
}
// UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
@@ -84,7 +84,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
}
}
if policyIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
util.WriteError(status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}
@@ -94,9 +94,9 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -118,12 +118,12 @@ func (h *Policies) savePolicy(
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
return
}
if len(req.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
return
}
@@ -137,31 +137,31 @@ func (h *Policies) savePolicy(
Enabled: req.Enabled,
Description: req.Description,
}
for _, rule := range req.Rules {
for _, r := range req.Rules {
pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: rule.Name,
Destinations: groupMinimumsToStrings(account, rule.Destinations),
Sources: groupMinimumsToStrings(account, rule.Sources),
Bidirectional: rule.Bidirectional,
ID: policyID, //TODO: when policy can contain multiple rules, need refactor
Name: r.Name,
Destinations: groupMinimumsToStrings(account, r.Destinations),
Sources: groupMinimumsToStrings(account, r.Sources),
Bidirectional: r.Bidirectional,
}
pr.Enabled = rule.Enabled
if rule.Description != nil {
pr.Description = *rule.Description
pr.Enabled = r.Enabled
if r.Description != nil {
pr.Description = *r.Description
}
switch rule.Action {
switch r.Action {
case api.PolicyRuleUpdateActionAccept:
pr.Action = server.PolicyTrafficActionAccept
case api.PolicyRuleUpdateActionDrop:
pr.Action = server.PolicyTrafficActionDrop
default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w)
return
}
switch rule.Protocol {
switch r.Protocol {
case api.PolicyRuleUpdateProtocolAll:
pr.Protocol = server.PolicyRuleProtocolALL
case api.PolicyRuleUpdateProtocolTcp:
@@ -171,14 +171,14 @@ func (h *Policies) savePolicy(
case api.PolicyRuleUpdateProtocolIcmp:
pr.Protocol = server.PolicyRuleProtocolICMP
default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
util.WriteError(status.Errorf(status.InvalidArgument, "unknown protocol type: %v", r.Protocol), w)
return
}
if rule.Ports != nil && len(*rule.Ports) != 0 {
for _, v := range *rule.Ports {
if r.Ports != nil && len(*r.Ports) != 0 {
for _, v := range *r.Ports {
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
return
}
pr.Ports = append(pr.Ports, v)
@@ -189,16 +189,16 @@ func (h *Policies) savePolicy(
switch pr.Protocol {
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
if len(pr.Ports) != 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
return
}
if !pr.Bidirectional {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
if !pr.Bidirectional && len(pr.Ports) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
}
@@ -210,26 +210,26 @@ func (h *Policies) savePolicy(
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
}
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
util.WriteError(r.Context(), err, w)
if err := h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil {
util.WriteError(err, w)
return
}
resp := toPolicyResponse(account, &policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
util.WriteJSONObject(w, resp)
}
// DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
aID := account.Id
@@ -237,24 +237,24 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
util.WriteError(r.Context(), err, w)
if err = h.accountManager.DeletePolicy(aID, policyID, user.Id); err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
// GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -263,25 +263,25 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
policy, err := h.accountManager.GetPolicy(account.Id, policyID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
util.WriteJSONObject(w, resp)
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
}
}

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -31,21 +30,21 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return &Policies{
accountManager: &mock_server.MockAccountManager{
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) {
policy, ok := testPolicies[policyID]
if !ok {
return nil, status.Errorf(status.NotFound, "policy not found")
}
return policy, nil
},
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
SavePolicyFunc: func(_, _ string, policy *server.Policy) error {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,

View File

@@ -37,15 +37,15 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
// GetAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
accountPostureChecks, err := p.accountManager.ListPostureChecks(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -54,22 +54,22 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, postureChecks)
util.WriteJSONObject(w, postureChecks)
}
// UpdatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
@@ -81,7 +81,7 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
}
}
if postureChecksIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
util.WriteError(status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
return
}
@@ -91,9 +91,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
// CreatePostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -103,50 +103,50 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http
// GetPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
postureChecks, err := p.accountManager.GetPostureChecks(account.Id, postureChecksID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
}
// DeletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
util.WriteError(r.Context(), err, w)
if err = p.accountManager.DeletePostureChecks(account.Id, postureChecksID, user.Id); err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
// savePostureChecks handles posture checks create and update
@@ -169,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}
@@ -177,14 +177,14 @@ func (p *PostureChecksHandler) savePostureChecks(
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
util.WriteError(r.Context(), err, w)
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
}

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -34,14 +33,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return &PostureChecksHandler{
accountManager: &mock_server.MockAccountManager{
GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
GetPostureChecksFunc: func(accountID, postureChecksID, userID string) (*posture.Checks, error) {
p, ok := testPostureChecks[postureChecksID]
if !ok {
return nil, status.Errorf(status.NotFound, "posture checks not found")
}
return p, nil
},
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
@@ -51,7 +50,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return nil
},
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID]
if !ok {
return status.Errorf(status.NotFound, "posture checks not found")
@@ -60,14 +59,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return nil
},
ListPostureChecksFunc: func(_ context.Context, accountID, userID string) ([]*posture.Checks, error) {
ListPostureChecksFunc: func(accountID, userID string) ([]*posture.Checks, error) {
accountPostureChecks := make([]*posture.Checks, len(testPostureChecks))
for _, p := range testPostureChecks {
accountPostureChecks = append(accountPostureChecks, p)
}
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,

View File

@@ -43,36 +43,36 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
// GetAllRoutes returns the list of routes for the account
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
apiRoutes := make([]*api.Route, 0)
for _, route := range routes {
route, err := toRouteResponse(route)
for _, r := range routes {
route, err := toRouteResponse(r)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
apiRoutes = append(apiRoutes, route)
}
util.WriteJSONObject(r.Context(), w, apiRoutes)
util.WriteJSONObject(w, apiRoutes)
}
// CreateRoute handles route creation request
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -84,7 +84,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
}
if err := h.validateRoute(req); err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -94,7 +94,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
if req.Domains != nil {
d, err := validateDomains(*req.Domains)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return
}
domains = d
@@ -102,7 +102,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
} else if req.Network != nil {
networkType, newPrefix, err = route.ParseNetwork(*req.Network)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
}
@@ -120,24 +120,24 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
// Do not allow non-Linux peers
if peer := account.GetPeer(peerId); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
return
}
}
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
routes, err := toRouteResponse(newRoute)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
util.WriteJSONObject(r.Context(), w, routes)
util.WriteJSONObject(w, routes)
}
func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
@@ -168,22 +168,22 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
// UpdateRoute handles update to a route identified by a given ID
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
routeID := vars["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
_, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -195,7 +195,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
}
if err := h.validateRoute(req); err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -207,7 +207,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
// do not allow non Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
return
}
}
@@ -226,7 +226,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
if req.Domains != nil {
d, err := validateDomains(*req.Domains)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return
}
newRoute.Domains = d
@@ -234,7 +234,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
} else if req.Network != nil {
newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
}
@@ -247,73 +247,73 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.PeerGroups = *req.PeerGroups
}
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
routes, err := toRouteResponse(newRoute)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
util.WriteJSONObject(r.Context(), w, routes)
util.WriteJSONObject(w, routes)
}
// DeleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
// GetRoute handles a route Get request identified by ID
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
return
}
routes, err := toRouteResponse(foundRoute)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
util.WriteJSONObject(r.Context(), w, routes)
util.WriteJSONObject(w, routes)
}
func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -90,7 +89,7 @@ var testingAccount = &server.Account{
func initRoutesTestData() *RoutesHandler {
return &RoutesHandler{
accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) {
if routeID == existingRouteID {
return baseExistingRoute, nil
}
@@ -105,7 +104,7 @@ func initRoutesTestData() *RoutesHandler {
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
},
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
CreateRouteFunc: func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
@@ -127,19 +126,19 @@ func initRoutesTestData() *RoutesHandler {
KeepRoute: keepRoute,
}, nil
},
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
SaveRouteFunc: func(_, _ string, r *route.Route) error {
if r.Peer == notFoundPeerID {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
}
return nil
},
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error {
if routeID != existingRouteID {
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
}
return nil
},
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, testingAccount.Users["test_user"], nil
},
},

View File

@@ -1,7 +1,6 @@
package http
import (
"context"
"encoding/json"
"net/http"
"time"
@@ -35,9 +34,9 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
// CreateSetupKey is a POST requests that creates a new SetupKey
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -49,13 +48,13 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
return
}
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
return
}
@@ -64,7 +63,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
day := time.Hour * 24
year := day * 365
if expiresIn < day || expiresIn > year {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
return
}
@@ -76,54 +75,54 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
if req.Ephemeral != nil {
ephemeral = *req.Ephemeral
}
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
writeSuccess(r.Context(), w, setupKey)
writeSuccess(w, setupKey)
}
// GetSetupKey is a GET request to get a SetupKey by ID
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return
}
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
writeSuccess(r.Context(), w, key)
writeSuccess(w, key)
}
// UpdateSetupKey is a PUT request to update server.SetupKey
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return
}
@@ -135,12 +134,12 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return
}
if req.AutoGroups == nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return
}
@@ -150,26 +149,26 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey.Name = req.Name
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
writeSuccess(r.Context(), w, newKey)
writeSuccess(w, newKey)
}
// GetAllSetupKeys is a GET request that returns a list of SetupKey
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -178,15 +177,15 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
}
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
util.WriteJSONObject(w, apiSetupKeys)
}
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(err, w)
return
}
}

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -34,7 +33,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
) *SetupKeysHandler {
return &SetupKeysHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
@@ -50,7 +49,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
},
}, user, nil
},
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool,
) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type {
@@ -60,7 +59,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
return nil, fmt.Errorf("failed creating setup key")
},
GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
switch keyID {
case defaultKey.Id:
return defaultKey, nil
@@ -71,14 +70,14 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
},
SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil
}
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil
},
},

View File

@@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
userID := vars["userId"]
if len(userID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
existingUser, ok := account.Users[userID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
util.WriteError(status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
return
}
@@ -74,11 +74,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
return
}
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
newUser, err := h.accountManager.SaveUser(account.Id, user.Id, &server.User{
Id: userID,
Role: userRole,
AutoGroups: req.AutoGroups,
@@ -88,10 +88,10 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
})
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
}
// DeleteUser is a DELETE request to delete a user
@@ -102,26 +102,26 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
@@ -132,9 +132,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
@@ -146,7 +146,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
}
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
return
}
@@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name
}
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{
Email: email,
Name: name,
Role: req.Role,
@@ -169,10 +169,10 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
Issued: server.UserIssuedAPI,
})
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
}
// GetAllUsers returns a list of users of the account this user belongs to.
@@ -184,42 +184,42 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
serviceUser := r.URL.Query().Get("service_user")
users := make([]*api.User, 0)
for _, d := range data {
if d.NonDeletable {
for _, r := range data {
if r.NonDeletable {
continue
}
if serviceUser == "" {
users = append(users, toUserResponse(d, claims.UserId))
users = append(users, toUserResponse(r, claims.UserId))
continue
}
includeServiceUser, err := strconv.ParseBool(serviceUser)
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
log.Debugf("Should include service user: %v", includeServiceUser)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
return
}
if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(d, claims.UserId))
if includeServiceUser == r.IsServiceUser {
users = append(users, toUserResponse(r, claims.UserId))
}
}
util.WriteJSONObject(r.Context(), w, users)
util.WriteJSONObject(w, users)
}
// InviteUser resend invitations to users who haven't activated their accounts,
@@ -231,26 +231,26 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
util.WriteJSONObject(w, emptyObject{})
}
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -64,10 +63,10 @@ var usersTestAccount = &server.Account{
func initUsersTestData() *UsersHandler {
return &UsersHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
for _, v := range usersTestAccount.Users {
users = append(users, &server.UserInfo{
@@ -82,13 +81,13 @@ func initUsersTestData() *UsersHandler {
}
return users, nil
},
CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
return key, nil
},
DeleteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
}
@@ -97,7 +96,7 @@ func initUsersTestData() *UsersHandler {
}
return nil
},
SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) {
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
if update.Id == notFoundUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
}
@@ -112,7 +111,7 @@ func initUsersTestData() *UsersHandler {
}
return info, nil
},
InviteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
if initiatorUserID != existingUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
}

View File

@@ -1,7 +1,6 @@
package util
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -20,12 +19,12 @@ type ErrorResponse struct {
}
// WriteJSONObject simply writes object to the HTTP response in JSON format
func WriteJSONObject(ctx context.Context, w http.ResponseWriter, obj interface{}) {
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(obj)
if err != nil {
WriteError(ctx, err, w)
WriteError(err, w)
return
}
}
@@ -77,8 +76,8 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
// WriteError converts an error to an JSON error response.
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
log.WithContext(ctx).Errorf("got a handler error: %s", err.Error())
func WriteError(err error, w http.ResponseWriter) {
log.Errorf("got a handler error: %s", err.Error())
errStatus, ok := status.FromError(err)
httpStatus := http.StatusInternalServerError
msg := "internal server error"
@@ -107,7 +106,7 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
msg = strings.ToLower(err.Error())
} else {
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
log.WithContext(ctx).Error(unhandledMSG)
log.Error(unhandledMSG)
}
WriteErrorResponse(msg, httpStatus, w)

View File

@@ -183,7 +183,7 @@ func (c *Auth0Credentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token
func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
var res *http.Response
reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
@@ -200,7 +200,7 @@ func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response,
req.Header.Add("content-type", "application/json")
log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
log.Debug("requesting new jwt token for idp manager")
res, err = c.httpClient.Do(req)
if err != nil {
@@ -247,7 +247,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo
}
// Authenticate retrieves access token to use the Auth0 Management API
func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
c.mux.Lock()
defer c.mux.Unlock()
@@ -260,14 +260,14 @@ func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
return c.jwtToken, nil
}
res, err := c.requestJWTToken(ctx)
res, err := c.requestJWTToken()
if err != nil {
return c.jwtToken, err
}
defer func() {
err = res.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
log.Errorf("error while closing get jwt token response body: %v", err)
}
}()
@@ -301,8 +301,8 @@ func requestByUserIDURL(authIssuer, userID string) string {
}
// GetAccount returns all the users for a given profile. Calls Auth0 API.
func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
@@ -353,7 +353,7 @@ func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*Us
return nil, err
}
log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
log.Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
err = res.Body.Close()
if err != nil {
@@ -365,7 +365,7 @@ func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*Us
}
if len(batch) == 0 || len(batch) < resultsPerPage {
log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
log.Debugf("finished loading users for accountID %s", accountID)
return list, nil
}
}
@@ -374,8 +374,8 @@ func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*Us
}
// GetUserDataByID requests user data from auth0 via ID
func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
@@ -414,7 +414,7 @@ func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appM
defer func() {
err = res.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
log.Errorf("error while closing update user app metadata response body: %v", err)
}
}()
@@ -426,9 +426,9 @@ func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appM
}
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map
func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
jwtToken, err := am.credentials.Authenticate(ctx)
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
@@ -449,7 +449,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
log.Debugf("updating IdP metadata for user %s", userID)
res, err := am.httpClient.Do(req)
if err != nil {
@@ -466,7 +466,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string
defer func() {
err = res.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
log.Errorf("error while closing update user app metadata response body: %v", err)
}
}()
@@ -530,9 +530,9 @@ func buildUserExportRequest() (string, error) {
}
func (am *Auth0Manager) createRequest(
ctx context.Context, method string, endpoint string, body io.Reader,
method string, endpoint string, body io.Reader,
) (*http.Request, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
@@ -548,8 +548,8 @@ func (am *Auth0Manager) createRequest(
return req, nil
}
func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr))
if err != nil {
return nil, err
}
@@ -560,20 +560,20 @@ func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string,
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
payloadString, err := buildUserExportRequest()
if err != nil {
return nil, err
}
exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString)
if err != nil {
return nil, err
}
jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@@ -583,7 +583,7 @@ func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserD
defer func() {
err = jobResp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
log.Errorf("error while closing update user app metadata response body: %v", err)
}
}()
if jobResp.StatusCode != 200 {
@@ -597,13 +597,13 @@ func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserD
body, err := io.ReadAll(jobResp.Body)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
log.Debugf("Couldn't read export job response; %v", err)
return nil, err
}
err = am.helper.Unmarshal(body, &exportJobResp)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
log.Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err
}
@@ -614,16 +614,16 @@ func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserD
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
}
log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
log.Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
done, downloadLink, err := am.checkExportJobStatus(exportJobResp.ID)
if err != nil {
log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
log.Debugf("Failed at getting status checks from exportJob; %v", err)
return nil, err
}
if done {
return am.downloadProfileExport(ctx, downloadLink)
return am.downloadProfileExport(downloadLink)
}
return nil, fmt.Errorf("failed extracting user profiles from auth0")
@@ -632,13 +632,13 @@ func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserD
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
if err != nil {
return nil, err
}
@@ -651,7 +651,7 @@ func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*Us
err = am.helper.Unmarshal(body, &userResp)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
log.Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err
}
@@ -659,13 +659,13 @@ func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*Us
}
// CreateUser creates a new user in Auth0 Idp and sends an invite
func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
if err != nil {
return nil, err
}
req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
req, err := am.createPostRequest("/api/v2/users", payloadString)
if err != nil {
return nil, err
}
@@ -676,7 +676,7 @@ func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID,
resp, err := am.httpClient.Do(req)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@@ -686,7 +686,7 @@ func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID,
defer func() {
err = resp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
log.Errorf("error while closing create user response body: %v", err)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
@@ -700,13 +700,13 @@ func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID,
body, err := io.ReadAll(resp.Body)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
log.Debugf("Couldn't read export job response; %v", err)
return nil, err
}
err = am.helper.Unmarshal(body, &createResp)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
log.Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err
}
@@ -714,14 +714,14 @@ func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID,
return nil, fmt.Errorf("couldn't create user: response %v", resp)
}
log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
log.Debugf("created user %s in account %s", createResp.ID, accountID)
return &createResp, nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
func (am *Auth0Manager) InviteUserByID(userID string) error {
userVerificationReq := userVerificationJobRequest{
UserID: userID,
}
@@ -731,14 +731,14 @@ func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error
return err
}
req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload))
if err != nil {
return err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@@ -748,7 +748,7 @@ func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error
defer func() {
err = resp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
log.Errorf("error while closing invite user response body: %v", err)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
@@ -762,15 +762,15 @@ func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error
}
// DeleteUser from Auth0
func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
func (am *Auth0Manager) DeleteUser(userID string) error {
req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
if err != nil {
return err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.WithContext(ctx).Debugf("execute delete request: %v", err)
log.Debugf("execute delete request: %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@@ -780,7 +780,7 @@ func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
defer func() {
err = resp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("close delete request body: %v", err)
log.Errorf("close delete request body: %v", err)
}
}()
if resp.StatusCode != 204 {
@@ -795,20 +795,20 @@ func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
// GetAllConnections returns detailed list of all connections filtered by given params.
// Note this method is not part of the IDP Manager interface as this is Auth0 specific.
func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, error) {
var connections []Connection
q := make(url.Values)
q.Set("strategy", strings.Join(strategy, ","))
req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
req, err := am.createRequest(http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
if err != nil {
return connections, err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.WithContext(ctx).Debugf("execute get connections request: %v", err)
log.Debugf("execute get connections request: %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@@ -818,7 +818,7 @@ func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string
defer func() {
err = resp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("close get connections request body: %v", err)
log.Errorf("close get connections request body: %v", err)
}
}()
if resp.StatusCode != 200 {
@@ -830,13 +830,13 @@ func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string
body, err := io.ReadAll(resp.Body)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
log.Debugf("Couldn't read get connections response; %v", err)
return connections, err
}
err = am.helper.Unmarshal(body, &connections)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
log.Debugf("Couldn't unmarshal get connection response; %v", err)
return connections, err
}
@@ -845,23 +845,23 @@ func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()
retry := time.NewTicker(10 * time.Second)
for {
select {
case <-ctx.Done():
log.WithContext(ctx).Debugf("Export job status stopped...\n")
log.Debugf("Export job status stopped...\n")
return false, "", ctx.Err()
case <-retry.C:
jwtToken, err := am.credentials.Authenticate(ctx)
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return false, "", err
}
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
body, err := doGetReq(am.httpClient, statusURL, jwtToken.AccessToken)
if err != nil {
return false, "", err
}
@@ -872,7 +872,7 @@ func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string)
return false, "", err
}
log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
log.Debugf("current export job status is %v", status.Status)
if status.Status != "completed" {
continue
@@ -884,8 +884,8 @@ func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string)
}
// downloadProfileExport downloads user profiles from auth0 batch job
func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
body, err := doGetReq(ctx, am.httpClient, location, "")
func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*UserData, error) {
body, err := doGetReq(am.httpClient, location, "")
if err != nil {
return nil, err
}
@@ -927,7 +927,7 @@ func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location stri
}
// Boilerplate implementation for Get Requests.
func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
@@ -945,7 +945,7 @@ func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken st
defer func() {
err = res.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
log.Errorf("error while closing body for url %s: %v", url, err)
}
}()
body, err := io.ReadAll(res.Body)

View File

@@ -1,7 +1,6 @@
package idp
import (
"context"
"encoding/json"
"fmt"
"io"
@@ -61,7 +60,7 @@ type mockAuth0Credentials struct {
err error
}
func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
func (mc *mockAuth0Credentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
@@ -127,7 +126,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) {
helper: testCase.helper,
}
res, err := creds.requestJWTToken(context.Background())
res, err := creds.requestJWTToken()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -296,7 +295,7 @@ func TestAuth0_Authenticate(t *testing.T) {
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate(context.Background())
_, err := creds.Authenticate()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -418,7 +417,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")

View File

@@ -116,7 +116,7 @@ func (ac *AuthentikCredentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token.
func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) {
data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("username", ac.clientConfig.Username)
@@ -131,7 +131,7 @@ func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Resp
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
log.Debug("requesting new jwt token for authentik idp manager")
resp, err := ac.httpClient.Do(req)
if err != nil {
@@ -183,7 +183,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (
}
// Authenticate retrieves access token to use the authentik management API.
func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
ac.mux.Lock()
defer ac.mux.Unlock()
@@ -197,7 +197,7 @@ func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, err
return ac.jwtToken, nil
}
resp, err := ac.requestJWTToken(ctx)
resp, err := ac.requestJWTToken()
if err != nil {
return ac.jwtToken, err
}
@@ -214,13 +214,13 @@ func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, err
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from authentik via ID.
func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
ctx, err := am.authenticationContext(ctx)
func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
}
@@ -254,8 +254,8 @@ func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string,
}
// GetAccount returns all the users for a given profile.
func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := am.getAllUsers(ctx)
func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := am.getAllUsers()
if err != nil {
return nil, err
}
@@ -274,8 +274,8 @@ func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
users, err := am.getAllUsers(ctx)
func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := am.getAllUsers()
if err != nil {
return nil, err
}
@@ -291,12 +291,12 @@ func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*U
}
// getAllUsers returns all users in a Authentik account.
func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
page := int32(1)
for {
ctx, err := am.authenticationContext(ctx)
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
}
@@ -329,14 +329,14 @@ func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error
}
// CreateUser creates a new user in authentik Idp and sends an invitation.
func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
ctx, err := am.authenticationContext(ctx)
func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
}
@@ -368,13 +368,13 @@ func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
func (am *AuthentikManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Authentik
func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
ctx, err := am.authenticationContext(ctx)
func (am *AuthentikManager) DeleteUser(userID string) error {
ctx, err := am.authenticationContext()
if err != nil {
return err
}
@@ -404,8 +404,8 @@ func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error
return nil
}
func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}

View File

@@ -1,7 +1,6 @@
package idp
import (
"context"
"fmt"
"io"
"strings"
@@ -139,7 +138,7 @@ func TestAuthentikRequestJWTToken(t *testing.T) {
helper: testCase.helper,
}
resp, err := creds.requestJWTToken(context.Background())
resp, err := creds.requestJWTToken()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -305,7 +304,7 @@ func TestAuthentikAuthenticate(t *testing.T) {
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate(context.Background())
_, err := creds.Authenticate()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@@ -1,7 +1,6 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
@@ -111,7 +110,7 @@ func (ac *AzureCredentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token.
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("client_secret", ac.clientConfig.ClientSecret)
@@ -133,7 +132,7 @@ func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
log.Debug("requesting new jwt token for azure idp manager")
resp, err := ac.httpClient.Do(req)
if err != nil {
@@ -185,7 +184,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT
}
// Authenticate retrieves access token to use the azure Management API.
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
ac.mux.Lock()
defer ac.mux.Unlock()
@@ -199,7 +198,7 @@ func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error)
return ac.jwtToken, nil
}
resp, err := ac.requestJWTToken(ctx)
resp, err := ac.requestJWTToken()
if err != nil {
return ac.jwtToken, err
}
@@ -216,16 +215,16 @@ func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error)
}
// CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get(ctx, "users/"+userID, q)
body, err := am.get("users/"+userID, q)
if err != nil {
return nil, err
}
@@ -248,11 +247,11 @@ func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appM
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get(ctx, "users/"+email, q)
body, err := am.get("users/"+email, q)
if err != nil {
return nil, err
}
@@ -274,8 +273,8 @@ func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*Us
}
// GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := am.getAllUsers(ctx)
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := am.getAllUsers()
if err != nil {
return nil, err
}
@@ -294,8 +293,8 @@ func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*Us
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
users, err := am.getAllUsers(ctx)
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := am.getAllUsers()
if err != nil {
return nil, err
}
@@ -311,19 +310,19 @@ func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserD
}
// UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
func (am *AzureManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Azure.
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
jwtToken, err := am.credentials.Authenticate(ctx)
func (am *AzureManager) DeleteUser(userID string) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
@@ -336,7 +335,7 @@ func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.WithContext(ctx).Debugf("delete idp user %s", userID)
log.Debugf("delete idp user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
@@ -359,7 +358,7 @@ func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
}
// getAllUsers returns all users in an Azure AD account.
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
func (am *AzureManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
q := url.Values{}
@@ -367,7 +366,7 @@ func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
q.Add("$top", "500")
for nextLink := "users"; nextLink != ""; {
body, err := am.get(ctx, nextLink, q)
body, err := am.get(nextLink, q)
if err != nil {
return nil, err
}
@@ -392,8 +391,8 @@ func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
}
// get perform Get requests.
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}

View File

@@ -1,7 +1,6 @@
package idp
import (
"context"
"fmt"
"testing"
"time"
@@ -102,7 +101,7 @@ func TestAzureAuthenticate(t *testing.T) {
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate(context.Background())
_, err := creds.Authenticate()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@@ -39,12 +39,12 @@ type GoogleWorkspaceCredentials struct {
appMetrics telemetry.AppMetrics
}
func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
func (gc *GoogleWorkspaceCredentials) Authenticate() (JWTToken, error) {
return JWTToken{}, nil
}
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -66,7 +66,7 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient
}
// Create a new Admin SDK Directory service client
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
adminCredentials, err := getGoogleCredentials(config.ServiceAccountKey)
if err != nil {
return nil, err
}
@@ -90,12 +90,12 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from Google Workspace via ID.
func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
user, err := gm.usersService.Get(userID).Do()
if err != nil {
return nil, err
@@ -112,7 +112,7 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID stri
}
// GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := gm.getAllUsers()
if err != nil {
return nil, err
@@ -132,7 +132,7 @@ func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := gm.getAllUsers()
if err != nil {
return nil, err
@@ -177,13 +177,13 @@ func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
}
// CreateUser creates a new user in Google Workspace and sends an invitation.
func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) {
user, err := gm.usersService.Get(email).Do()
if err != nil {
return nil, err
@@ -201,12 +201,12 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from GoogleWorkspace.
func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
if err := gm.usersService.Delete(userID).Do(); err != nil {
return err
}
@@ -222,8 +222,8 @@ func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) e
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
// If that fails, it falls back to using the default Google credentials path.
// It returns the retrieved credentials or an error if unsuccessful.
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) {
log.Debug("retrieving google credentials from the base64 encoded service account key")
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
if err != nil {
return nil, fmt.Errorf("failed to decode service account key: %w", err)
@@ -239,8 +239,8 @@ func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*googl
return creds, nil
}
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
log.WithContext(ctx).Debug("falling back to default google credentials location")
log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
log.Debug("falling back to default google credentials location")
creds, err = google.FindDefaultCredentials(
context.Background(),

View File

@@ -1,7 +1,6 @@
package idp
import (
"context"
"fmt"
"net/http"
"strings"
@@ -17,14 +16,14 @@ const (
// Manager idp manager interface
type Manager interface {
UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
InviteUserByID(ctx context.Context, userID string) error
DeleteUser(ctx context.Context, userID string) error
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error)
GetAccount(accountId string) ([]*UserData, error)
GetAllAccounts() (map[string][]*UserData, error)
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmail(email string) ([]*UserData, error)
InviteUserByID(userID string) error
DeleteUser(userID string) error
}
// ClientConfig defines common client configuration for all IdP manager
@@ -52,7 +51,7 @@ type Config struct {
// ManagerCredentials interface that authenticates using the credential of each type of idp
type ManagerCredentials interface {
Authenticate(ctx context.Context) (JWTToken, error)
Authenticate() (JWTToken, error)
}
// ManagerHTTPClient http client interface for API calls
@@ -92,7 +91,7 @@ type JWTToken struct {
}
// NewManager returns a new idp manager based on the configuration that it receives
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
if config.ClientConfig != nil {
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
}
@@ -176,7 +175,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
CustomerID: config.ExtraConfig["CustomerId"],
}
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
return NewGoogleWorkspaceManager(googleClientConfig, appMetrics)
case "jumpcloud":
jumpcloudConfig := JumpCloudClientConfig{
APIToken: config.ExtraConfig["ApiToken"],

View File

@@ -74,7 +74,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
}
// Authenticate retrieves access token to use the JumpCloud user API.
func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
func (jc *JumpCloudCredentials) Authenticate() (JWTToken, error) {
return JWTToken{}, nil
}
@@ -85,12 +85,12 @@ func (jm *JumpCloudManager) authenticationContext() context.Context {
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from JumpCloud via ID.
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
authCtx := jm.authenticationContext()
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
if err != nil {
@@ -116,7 +116,7 @@ func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, ap
}
// GetAccount returns all the users for a given profile.
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil {
@@ -148,7 +148,7 @@ func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil {
@@ -177,13 +177,13 @@ func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*Use
}
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
func (jm *JumpCloudManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) {
searchFilter := map[string]interface{}{
"searchFilter": map[string]interface{}{
"filter": []string{email},
@@ -219,12 +219,12 @@ func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
func (jm *JumpCloudManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from jumpCloud directory
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
func (jm *JumpCloudManager) DeleteUser(userID string) error {
authCtx := jm.authenticationContext()
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
if err != nil {

View File

@@ -1,7 +1,6 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
@@ -110,7 +109,7 @@ func (kc *KeycloakCredentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token.
func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
data := url.Values{}
data.Set("client_id", kc.clientConfig.ClientID)
data.Set("client_secret", kc.clientConfig.ClientSecret)
@@ -123,7 +122,7 @@ func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Respo
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
log.Debug("requesting new jwt token for keycloak idp manager")
resp, err := kc.httpClient.Do(req)
if err != nil {
@@ -175,7 +174,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J
}
// Authenticate retrieves access token to use the keycloak Management API.
func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
kc.mux.Lock()
defer kc.mux.Unlock()
@@ -189,7 +188,7 @@ func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, erro
return kc.jwtToken, nil
}
resp, err := kc.requestJWTToken(ctx)
resp, err := kc.requestJWTToken()
if err != nil {
return kc.jwtToken, err
}
@@ -206,18 +205,18 @@ func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, erro
}
// CreateUser creates a new user in keycloak Idp and sends an invite.
func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
q := url.Values{}
q.Add("email", email)
q.Add("exact", "true")
body, err := km.get(ctx, "users", q)
body, err := km.get("users", q)
if err != nil {
return nil, err
}
@@ -241,8 +240,8 @@ func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]
}
// GetUserDataByID requests user data from keycloak via ID.
func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
body, err := km.get(ctx, "users/"+userID, nil)
func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) {
body, err := km.get("users/"+userID, nil)
if err != nil {
return nil, err
}
@@ -261,8 +260,8 @@ func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _
}
// GetAccount returns all the users for a given account profile.
func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
profiles, err := km.fetchAllUserProfiles(ctx)
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
profiles, err := km.fetchAllUserProfiles()
if err != nil {
return nil, err
}
@@ -284,8 +283,8 @@ func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
profiles, err := km.fetchAllUserProfiles(ctx)
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
profiles, err := km.fetchAllUserProfiles()
if err != nil {
return nil, err
}
@@ -304,19 +303,19 @@ func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*Us
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
func (km *KeycloakManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Keycloak by user ID.
func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
jwtToken, err := km.credentials.Authenticate(ctx)
func (km *KeycloakManager) DeleteUser(userID string) error {
jwtToken, err := km.credentials.Authenticate()
if err != nil {
return err
}
@@ -354,8 +353,8 @@ func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error
return nil
}
func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
totalUsers, err := km.totalUsersCount(ctx)
func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
totalUsers, err := km.totalUsersCount()
if err != nil {
return nil, err
}
@@ -363,7 +362,7 @@ func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloak
q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get(ctx, "users", q)
body, err := km.get("users", q)
if err != nil {
return nil, err
}
@@ -378,8 +377,8 @@ func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloak
}
// get perform Get requests.
func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := km.credentials.Authenticate(ctx)
func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := km.credentials.Authenticate()
if err != nil {
return nil, err
}
@@ -415,8 +414,8 @@ func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Value
// totalUsersCount returns the total count of all user created.
// Used when fetching all registered accounts with pagination.
func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
body, err := km.get(ctx, "users/count", nil)
func (km *KeycloakManager) totalUsersCount() (*int, error) {
body, err := km.get("users/count", nil)
if err != nil {
return nil, err
}

View File

@@ -1,7 +1,6 @@
package idp
import (
"context"
"fmt"
"io"
"strings"
@@ -129,7 +128,7 @@ func TestKeycloakRequestJWTToken(t *testing.T) {
helper: testCase.helper,
}
resp, err := creds.requestJWTToken(context.Background())
resp, err := creds.requestJWTToken()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -295,7 +294,7 @@ func TestKeycloakAuthenticate(t *testing.T) {
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate(context.Background())
_, err := creds.Authenticate()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@@ -1,79 +1,77 @@
package idp
import "context"
// MockIDP is a mock implementation of the IDP interface
type MockIDP struct {
UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
InviteUserByIDFunc func(ctx context.Context, userID string) error
DeleteUserFunc func(ctx context.Context, userID string) error
UpdateUserAppMetadataFunc func(userId string, appMetadata AppMetadata) error
GetUserDataByIDFunc func(userId string, appMetadata AppMetadata) (*UserData, error)
GetAccountFunc func(accountId string) ([]*UserData, error)
GetAllAccountsFunc func() (map[string][]*UserData, error)
CreateUserFunc func(email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmailFunc func(email string) ([]*UserData, error)
InviteUserByIDFunc func(userID string) error
DeleteUserFunc func(userID string) error
}
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
func (m *MockIDP) UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error {
if m.UpdateUserAppMetadataFunc != nil {
return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
return m.UpdateUserAppMetadataFunc(userId, appMetadata)
}
return nil
}
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
func (m *MockIDP) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) {
if m.GetUserDataByIDFunc != nil {
return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
return m.GetUserDataByIDFunc(userId, appMetadata)
}
return nil, nil
}
// GetAccount is a mock implementation of the IDP interface GetAccount method
func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
func (m *MockIDP) GetAccount(accountId string) ([]*UserData, error) {
if m.GetAccountFunc != nil {
return m.GetAccountFunc(ctx, accountId)
return m.GetAccountFunc(accountId)
}
return nil, nil
}
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
func (m *MockIDP) GetAllAccounts() (map[string][]*UserData, error) {
if m.GetAllAccountsFunc != nil {
return m.GetAllAccountsFunc(ctx)
return m.GetAllAccountsFunc()
}
return nil, nil
}
// CreateUser is a mock implementation of the IDP interface CreateUser method
func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
func (m *MockIDP) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
if m.CreateUserFunc != nil {
return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
return m.CreateUserFunc(email, name, accountID, invitedByEmail)
}
return nil, nil
}
// GetUserByEmail is a mock implementation of the IDP interface GetUserByEmail method
func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
func (m *MockIDP) GetUserByEmail(email string) ([]*UserData, error) {
if m.GetUserByEmailFunc != nil {
return m.GetUserByEmailFunc(ctx, email)
return m.GetUserByEmailFunc(email)
}
return nil, nil
}
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
func (m *MockIDP) InviteUserByID(userID string) error {
if m.InviteUserByIDFunc != nil {
return m.InviteUserByIDFunc(ctx, userID)
return m.InviteUserByIDFunc(userID)
}
return nil
}
// DeleteUser is a mock implementation of the IDP interface DeleteUser method
func (m *MockIDP) DeleteUser(ctx context.Context, userID string) error {
func (m *MockIDP) DeleteUser(userID string) error {
if m.DeleteUserFunc != nil {
return m.DeleteUserFunc(ctx, userID)
return m.DeleteUserFunc(userID)
}
return nil
}

View File

@@ -94,17 +94,17 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
}
// Authenticate retrieves access token to use the okta user API.
func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
func (oc *OktaCredentials) Authenticate() (JWTToken, error) {
return JWTToken{}, nil
}
// CreateUser creates a new user in okta Idp and sends an invitation.
func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserDataByID requests user data from keycloak via ID.
func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
user, resp, err := om.client.User.GetUser(context.Background(), userID)
if err != nil {
return nil, err
@@ -132,7 +132,7 @@ func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMeta
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
if err != nil {
return nil, err
@@ -160,7 +160,7 @@ func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserD
}
// GetAccount returns all the users for a given profile.
func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := om.getAllUsers()
if err != nil {
return nil, err
@@ -180,7 +180,7 @@ func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserD
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := om.getAllUsers()
if err != nil {
return nil, err
@@ -242,18 +242,18 @@ func (om *OktaManager) getAllUsers() ([]*UserData, error) {
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
func (om *OktaManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Okta
func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
func (om *OktaManager) DeleteUser(userID string) error {
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
if err != nil {
return err

View File

@@ -1,7 +1,6 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
@@ -150,7 +149,7 @@ func (zc *ZitadelCredentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token.
func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
data := url.Values{}
data.Set("client_id", zc.clientConfig.ClientID)
data.Set("client_secret", zc.clientConfig.ClientSecret)
@@ -164,7 +163,7 @@ func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Respon
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.WithContext(ctx).Debug("requesting new jwt token for zitadel idp manager")
log.Debug("requesting new jwt token for zitadel idp manager")
resp, err := zc.httpClient.Do(req)
if err != nil {
@@ -216,7 +215,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW
}
// Authenticate retrieves access token to use the Zitadel Management API.
func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
zc.mux.Lock()
defer zc.mux.Unlock()
@@ -230,7 +229,7 @@ func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error
return zc.jwtToken, nil
}
resp, err := zc.requestJWTToken(ctx)
resp, err := zc.requestJWTToken()
if err != nil {
return zc.jwtToken, err
}
@@ -247,7 +246,7 @@ func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error
}
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
firstLast := strings.SplitN(name, " ", 2)
var addUser = map[string]any{
@@ -270,7 +269,7 @@ func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID
return nil, err
}
body, err := zm.post(ctx, "users/human/_import", string(payload))
body, err := zm.post("users/human/_import", string(payload))
if err != nil {
return nil, err
}
@@ -301,7 +300,7 @@ func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
searchByEmail := zitadelAttributes{
"queries": {
{
@@ -317,7 +316,7 @@ func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*
return nil, err
}
body, err := zm.post(ctx, "users/_search", string(payload))
body, err := zm.post("users/_search", string(payload))
if err != nil {
return nil, err
}
@@ -341,8 +340,8 @@ func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*
}
// GetUserDataByID requests user data from zitadel via ID.
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
body, err := zm.get(ctx, "users/"+userID, nil)
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
body, err := zm.get("users/"+userID, nil)
if err != nil {
return nil, err
}
@@ -364,8 +363,8 @@ func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, ap
}
// GetAccount returns all the users for a given profile.
func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
body, err := zm.post(ctx, "users/_search", "")
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
body, err := zm.post("users/_search", "")
if err != nil {
return nil, err
}
@@ -393,8 +392,8 @@ func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
body, err := zm.post(ctx, "users/_search", "")
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
body, err := zm.post("users/_search", "")
if err != nil {
return nil, err
}
@@ -420,7 +419,7 @@ func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*Use
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Metadata values are base64 encoded.
func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
@@ -430,7 +429,7 @@ type inviteUserRequest struct {
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (zm *ZitadelManager) InviteUserByID(ctx context.Context, userID string) error {
func (zm *ZitadelManager) InviteUserByID(userID string) error {
inviteUser := inviteUserRequest{
Email: userID,
}
@@ -441,14 +440,14 @@ func (zm *ZitadelManager) InviteUserByID(ctx context.Context, userID string) err
}
// don't care about the body in the response
_, err = zm.post(ctx, fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
_, err = zm.post(fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
return err
}
// DeleteUser from Zitadel
func (zm *ZitadelManager) DeleteUser(ctx context.Context, userID string) error {
func (zm *ZitadelManager) DeleteUser(userID string) error {
resource := fmt.Sprintf("users/%s", userID)
if err := zm.delete(ctx, resource); err != nil {
if err := zm.delete(resource); err != nil {
return err
}
@@ -460,8 +459,8 @@ func (zm *ZitadelManager) DeleteUser(ctx context.Context, userID string) error {
}
// post perform Post requests.
func (zm *ZitadelManager) post(ctx context.Context, resource string, body string) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate(ctx)
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return nil, err
}
@@ -496,8 +495,8 @@ func (zm *ZitadelManager) post(ctx context.Context, resource string, body string
}
// delete perform Delete requests.
func (zm *ZitadelManager) delete(ctx context.Context, resource string) error {
jwtToken, err := zm.credentials.Authenticate(ctx)
func (zm *ZitadelManager) delete(resource string) error {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return err
}
@@ -532,8 +531,8 @@ func (zm *ZitadelManager) delete(ctx context.Context, resource string) error {
}
// get perform Get requests.
func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate(ctx)
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return nil, err
}

View File

@@ -1,7 +1,6 @@
package idp
import (
"context"
"fmt"
"io"
"strings"
@@ -109,7 +108,7 @@ func TestZitadelRequestJWTToken(t *testing.T) {
helper: testCase.helper,
}
resp, err := creds.requestJWTToken(context.Background())
resp, err := creds.requestJWTToken()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -275,7 +274,7 @@ func TestZitadelAuthenticate(t *testing.T) {
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate(context.Background())
_, err := creds.Authenticate()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@@ -1,10 +1,9 @@
package server
import (
"context"
"errors"
log "github.com/sirupsen/logrus"
"github.com/google/martian/v3/log"
"github.com/netbirdio/netbird/management/server/account"
)
@@ -20,22 +19,22 @@ import (
//
// Returns:
// - error: An error if any occurred during the process, otherwise returns nil
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
ok, err := am.GroupValidation(ctx, accountID, groups)
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
ok, err := am.GroupValidation(accountID, groups)
if err != nil {
log.WithContext(ctx).Debugf("error validating groups: %s", err.Error())
log.Debugf("error validating groups: %s", err.Error())
return err
}
if !ok {
log.WithContext(ctx).Debugf("invalid groups")
log.Debugf("invalid groups")
return errors.New("invalid groups")
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
a, err := am.Store.GetAccountByUser(ctx, userID)
a, err := am.Store.GetAccountByUser(userID)
if err != nil {
return err
}
@@ -49,14 +48,14 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
a.Settings.Extra = extra
}
extra.IntegratedValidatorGroups = groups
return am.Store.SaveAccount(ctx, a)
return am.Store.SaveAccount(a)
}
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) {
if len(groups) == 0 {
return true, nil
}
accountsGroups, err := am.ListGroups(ctx, accountId)
accountsGroups, err := am.ListGroups(accountId)
if err != nil {
return false, err
}

View File

@@ -1,8 +1,6 @@
package integrated_validator
import (
"context"
"github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -10,12 +8,12 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(ctx context.Context, accountID, peerID string) error
PeerDeleted(accountID, peerID string) error
SetPeerInvalidationListener(fn func(accountID string))
Stop(ctx context.Context)
Stop()
}

View File

@@ -2,7 +2,6 @@ package jwtclaims
import (
"bytes"
"context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
@@ -70,8 +69,8 @@ type JWTValidator struct {
}
// NewJWTValidator constructor
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(ctx, keysLocation)
func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation)
if err != nil {
return nil, err
}
@@ -103,19 +102,19 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
lock.Lock()
defer lock.Unlock()
refreshedKeys, err := getPemKeys(ctx, keysLocation)
refreshedKeys, err := getPemKeys(keysLocation)
if err != nil {
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
log.Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
refreshedKeys = keys
}
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
log.Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
keys = refreshedKeys
}
}
cert, err := getPemCert(ctx, token, keys)
cert, err := getPemCert(token, keys)
if err != nil {
return nil, err
}
@@ -137,19 +136,19 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
}
// ValidateAndParse validates the token and returns the parsed token
func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.options.CredentialsOptional {
log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)")
log.Debugf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil //nolint:nilnil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
log.Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
}
@@ -158,7 +157,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
// Check if there was an error in parsing...
if err != nil {
log.WithContext(ctx).Errorf("error parsing token: %v", err)
log.Errorf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
@@ -166,14 +165,14 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.options.SigningMethod.Alg(),
parsedToken.Header["alg"])
log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg)
log.Debugf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
log.WithContext(ctx).Debugf(errorMsg)
log.Debugf(errorMsg)
return nil, errors.New(errorMsg)
}
@@ -185,7 +184,7 @@ func (jwks *Jwks) stillValid() bool {
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
func getPemKeys(keysLocation string) (*Jwks, error) {
resp, err := http.Get(keysLocation)
if err != nil {
return nil, err
@@ -199,13 +198,13 @@ func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
}
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader)
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, err
}
func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) {
func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
cert := ""
@@ -218,7 +217,7 @@ func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, erro
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
return cert, nil
}
log.WithContext(ctx).Debugf("generating validation pem from JWK")
log.Debugf("generating validation pem from JWK")
return generatePemFromJWK(jwks.Keys[k])
}
@@ -285,7 +284,7 @@ func convertExponentStringToInt(stringExponent string) (int, error) {
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
func getMaxAgeFromCacheHeader(cacheControl string) int {
// Split into individual directives
directives := strings.Split(cacheControl, ",")
@@ -296,7 +295,7 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
log.WithContext(ctx).Debugf("error parsing max-age: %v", err)
log.Debugf("error parsing max-age: %v", err)
return 0
}

View File

@@ -406,7 +406,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
store, cleanUp, err := NewTestStoreFromJson(config.Datadir)
if err != nil {
return nil, "", err
}
@@ -414,7 +414,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
return nil, "", err
@@ -422,7 +422,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
if err != nil {
return nil, "", err
}

Some files were not shown because too many files have changed in this diff Show More