mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-03 08:03:46 -04:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
840b07c784 | ||
|
|
85e991ff78 | ||
|
|
f9845e53a0 | ||
|
|
765aba2c1c | ||
|
|
7cb81f1d70 | ||
|
|
cea19de667 | ||
|
|
29e5eceb6b | ||
|
|
0f63737330 | ||
|
|
bf518c5fba | ||
|
|
eab6183a8e | ||
|
|
4517da8b3a | ||
|
|
9c0d923124 | ||
|
|
6857734c48 | ||
|
|
3b019800f8 | ||
|
|
4cd4f88666 | ||
|
|
d2157bda66 | ||
|
|
43a8ba97e3 | ||
|
|
17874771cc | ||
|
|
f6ccf6b97a | ||
|
|
6aae797baf | ||
|
|
aca054e51e | ||
|
|
10cee8f46e | ||
|
|
628673db20 | ||
|
|
eaa31c2dc6 | ||
|
|
25723e9b07 | ||
|
|
3cf4d5758f |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -173,7 +173,7 @@ jobs:
|
||||
retention-days: 3
|
||||
|
||||
release_ui_darwin:
|
||||
runs-on: macos-11
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
59
.github/workflows/test-infrastructure-files.yml
vendored
59
.github/workflows/test-infrastructure-files.yml
vendored
@@ -178,34 +178,79 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run script
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
- name: test Caddy file gen
|
||||
- name: test Caddy file gen postgres
|
||||
run: test -f Caddyfile
|
||||
- name: test docker-compose file gen
|
||||
|
||||
- name: test docker-compose file gen postgres
|
||||
run: test -f docker-compose.yml
|
||||
- name: test management.json file gen
|
||||
|
||||
- name: test management.json file gen postgres
|
||||
run: test -f management.json
|
||||
- name: test turnserver.conf file gen
|
||||
|
||||
- name: test turnserver.conf file gen postgres
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
- name: test zitadel.env file gen
|
||||
|
||||
- name: test zitadel.env file gen postgres
|
||||
run: test -f zitadel.env
|
||||
- name: test dashboard.env file gen
|
||||
|
||||
- name: test dashboard.env file gen postgres
|
||||
run: test -f dashboard.env
|
||||
|
||||
- name: test zdb.env file gen postgres
|
||||
run: test -f zdb.env
|
||||
|
||||
- name: Postgres run cleanup
|
||||
run: |
|
||||
docker-compose down --volumes --rmi all
|
||||
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
|
||||
|
||||
- name: run script with Zitadel CockroachDB
|
||||
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
env:
|
||||
NETBIRD_DOMAIN: use-ip
|
||||
ZITADEL_DATABASE: cockroach
|
||||
|
||||
- name: test Caddy file gen CockroachDB
|
||||
run: test -f Caddyfile
|
||||
|
||||
- name: test docker-compose file gen CockroachDB
|
||||
run: test -f docker-compose.yml
|
||||
|
||||
- name: test management.json file gen CockroachDB
|
||||
run: test -f management.json
|
||||
|
||||
- name: test turnserver.conf file gen CockroachDB
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
|
||||
- name: test zitadel.env file gen CockroachDB
|
||||
run: test -f zitadel.env
|
||||
|
||||
- name: test dashboard.env file gen CockroachDB
|
||||
run: test -f dashboard.env
|
||||
|
||||
test-download-geolite2-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: test script
|
||||
run: bash -x infrastructure_files/download-geolite2.sh
|
||||
|
||||
- name: test mmdb file exists
|
||||
run: test -f GeoLite2-City.mmdb
|
||||
|
||||
- name: test geonames file exists
|
||||
run: test -f geonames.db
|
||||
|
||||
@@ -3,8 +3,10 @@ builds:
|
||||
- id: netbird-ui-darwin
|
||||
dir: client/ui
|
||||
binary: netbird-ui
|
||||
env: [CGO_ENABLED=1]
|
||||
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- MACOSX_DEPLOYMENT_TARGET=11.0
|
||||
- MACOS_DEPLOYMENT_TARGET=11.0
|
||||
goos:
|
||||
- darwin
|
||||
goarch:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM alpine:3.18.5
|
||||
FROM alpine:3.19
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
|
||||
@@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -87,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
iv, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -101,6 +101,7 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
@@ -317,6 +318,13 @@ func (i *routerManager) createChain(table, newChain string) error {
|
||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||
}
|
||||
|
||||
// Add the loopback return rule to the NAT chain
|
||||
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
|
||||
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||
@@ -326,6 +334,30 @@ func (i *routerManager) createChain(table, newChain string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNATRule appends an iptables rule pair to the nat chain
|
||||
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
// inserting after loopback ignore rule
|
||||
err := i.iptablesClient.Insert(table, chain, 2, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification
|
||||
func genRuleSpec(jump, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump}
|
||||
|
||||
@@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
return m.router.AddRoutingRules(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
|
||||
@@ -22,6 +22,8 @@ const (
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
|
||||
loopbackInterface = "lo\x00"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
@@ -126,6 +128,22 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Add RETURN rule for loopback interface
|
||||
loRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte(loopbackInterface),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictReturn},
|
||||
},
|
||||
}
|
||||
r.conn.InsertRule(loRule)
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
@@ -138,28 +156,28 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -177,8 +195,8 @@ func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
// addRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
@@ -199,7 +217,7 @@ func (r *router) insertRoutingRule(format, chainName string, pair manager.Router
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainName],
|
||||
Exprs: expression,
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
err = manager.AddRoutingRules(testCase.InputPair)
|
||||
defer func() {
|
||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
}()
|
||||
|
||||
@@ -78,6 +78,11 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}()
|
||||
|
||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||
if r.Extra == nil {
|
||||
r.SetEdns0(4096, false)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
|
||||
@@ -282,8 +282,6 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
||||
@@ -291,6 +289,9 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
|
||||
userspace := e.wgInterface.IsUserspaceBind()
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
if e.config.RosenpassPermissive {
|
||||
@@ -1464,6 +1465,15 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
|
||||
}
|
||||
|
||||
func (e *Engine) restartEngine() {
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) startNetworkMonitor() {
|
||||
if !e.config.NetworkMonitor {
|
||||
log.Infof("Network monitor is disabled, not starting")
|
||||
@@ -1472,14 +1482,29 @@ func (e *Engine) startNetworkMonitor() {
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
go func() {
|
||||
var mu sync.Mutex
|
||||
var debounceTimer *time.Timer
|
||||
|
||||
// Start the network monitor with a callback, Start will block until the monitor is stopped,
|
||||
// a network change is detected, or an error occurs on start up
|
||||
err := e.networkMonitor.Start(e.ctx, func() {
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
// This function is called when a network change is detected
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
|
||||
// Set a new timer to debounce rapid network changes
|
||||
debounceTimer = time.AfterFunc(1*time.Second, func() {
|
||||
// This function is called after the debounce period
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
})
|
||||
})
|
||||
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
|
||||
log.Errorf("Network monitor: %v", err)
|
||||
|
||||
@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
//time.Sleep(250 * time.Millisecond)
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
||||
|
||||
@@ -1057,7 +1057,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -1068,13 +1068,13 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -45,24 +45,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
|
||||
// handle interface state changes
|
||||
case unix.RTM_IFINFO:
|
||||
ifinfo, err := parseInterfaceMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("Network monitor: error parsing interface message: %v", err)
|
||||
continue
|
||||
}
|
||||
if msg.Flags&unix.IFF_UP != 0 {
|
||||
continue
|
||||
}
|
||||
if (nexthopv4.Intf == nil || ifinfo.Index != nexthopv4.Intf.Index) && (nexthopv6.Intf == nil || ifinfo.Index != nexthopv6.Intf.Index) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
|
||||
go callback()
|
||||
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
@@ -94,24 +76,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
}
|
||||
}
|
||||
|
||||
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.InterfaceMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
|
||||
@@ -19,14 +19,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
return errors.New("no interfaces available")
|
||||
}
|
||||
|
||||
linkChan := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
||||
return fmt.Errorf("subscribe to link updates: %v", err)
|
||||
}
|
||||
|
||||
routeChan := make(chan netlink.RouteUpdate)
|
||||
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
|
||||
return fmt.Errorf("subscribe to route updates: %v", err)
|
||||
@@ -38,25 +33,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
|
||||
// handle interface state changes
|
||||
case update := <-linkChan:
|
||||
if (nexthopv4.Intf == nil || update.Index != int32(nexthopv4.Intf.Index)) && (nexthopv6.Intf == nil || update.Index != int32(nexthopv6.Intf.Index)) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch update.Header.Type {
|
||||
case syscall.RTM_DELLINK:
|
||||
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_NEWLINK:
|
||||
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
|
||||
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handle route changes
|
||||
case route := <-routeChan:
|
||||
// default route and main table
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -33,12 +34,8 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
return fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
if n, ok := initialNeighbors[nexthopv4.IP]; ok {
|
||||
neighborv4 = &n
|
||||
}
|
||||
if n, ok := initialNeighbors[nexthopv6.IP]; ok {
|
||||
neighborv6 = &n
|
||||
}
|
||||
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
|
||||
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
|
||||
}
|
||||
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
|
||||
|
||||
@@ -58,6 +55,16 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
}
|
||||
}
|
||||
|
||||
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
|
||||
if n, ok := initialNeighbors[nexthop.IP]; ok &&
|
||||
n.State != unreachable &&
|
||||
n.State != incomplete &&
|
||||
n.State != tbd {
|
||||
return &n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func changed(
|
||||
nexthopv4 systemops.Nexthop,
|
||||
neighborv4 *systemops.Neighbor,
|
||||
@@ -87,37 +94,64 @@ func changed(
|
||||
}
|
||||
|
||||
// routeChanged checks if the default routes still point to our nexthop/interface
|
||||
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes map[netip.Prefix]systemops.Route) bool {
|
||||
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
|
||||
if !nexthop.IP.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
var unspec netip.Prefix
|
||||
if nexthop.IP.Is6() {
|
||||
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
} else {
|
||||
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
unspec := getUnspecifiedPrefix(nexthop.IP)
|
||||
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
|
||||
|
||||
if r, ok := routes[unspec]; ok {
|
||||
if r.Nexthop != nexthop.IP || compareIntf(r.Interface, intf) != 0 {
|
||||
oldIntf, newIntf := "<nil>", "<nil>"
|
||||
if intf != nil {
|
||||
oldIntf = intf.Name
|
||||
}
|
||||
if r.Interface != nil {
|
||||
newIntf = r.Interface.Name
|
||||
}
|
||||
log.Infof("network monitor: default route changed: %s from %s (%s) to %s (%s)", r.Destination, nexthop.IP, oldIntf, r.Nexthop, newIntf)
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
log.Infof("network monitor: default route is gone")
|
||||
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
|
||||
|
||||
if !foundMatchingRoute {
|
||||
logRouteChange(nexthop.IP, intf)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
|
||||
if ip.Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
|
||||
var defaultRoutes []string
|
||||
foundMatchingRoute := false
|
||||
|
||||
for _, r := range routes {
|
||||
if r.Destination == unspec {
|
||||
routeInfo := formatRouteInfo(r)
|
||||
defaultRoutes = append(defaultRoutes, routeInfo)
|
||||
|
||||
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
|
||||
foundMatchingRoute = true
|
||||
log.Debugf("network monitor: found matching default route: %s", routeInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultRoutes, foundMatchingRoute
|
||||
}
|
||||
|
||||
func formatRouteInfo(r systemops.Route) string {
|
||||
newIntf := "<nil>"
|
||||
if r.Interface != nil {
|
||||
newIntf = r.Interface.Name
|
||||
}
|
||||
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
|
||||
}
|
||||
|
||||
func logRouteChange(ip netip.Addr, intf *net.Interface) {
|
||||
oldIntf := "<nil>"
|
||||
if intf != nil {
|
||||
oldIntf = intf.Name
|
||||
}
|
||||
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
|
||||
}
|
||||
|
||||
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
|
||||
@@ -127,7 +161,7 @@ func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, ne
|
||||
|
||||
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
|
||||
if n, ok := neighbors[nexthop.IP]; ok {
|
||||
if n.State != reachable && n.State != permanent {
|
||||
if n.State == unreachable || n.State == incomplete {
|
||||
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
|
||||
return true
|
||||
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
|
||||
@@ -165,18 +199,13 @@ func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
|
||||
return neighbours, nil
|
||||
}
|
||||
|
||||
func getRoutes() (map[netip.Prefix]systemops.Route, error) {
|
||||
func getRoutes() ([]systemops.Route, error) {
|
||||
entries, err := systemops.GetRoutes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
routes := make(map[netip.Prefix]systemops.Route, len(entries))
|
||||
for _, entry := range entries {
|
||||
routes[entry.Destination] = entry
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func stateFromInt(state uint8) string {
|
||||
|
||||
@@ -36,7 +36,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -51,7 +51,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -88,7 +88,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -124,7 +124,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -154,7 +154,7 @@ func TestConn_Status(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
|
||||
@@ -23,7 +23,8 @@ import (
|
||||
const (
|
||||
DefaultInterval = time.Minute
|
||||
|
||||
minInterval = 2 * time.Second
|
||||
minInterval = 2 * time.Second
|
||||
failureInterval = 5 * time.Second
|
||||
|
||||
addAllowedIP = "add allowed IP %s: %w"
|
||||
)
|
||||
@@ -160,7 +161,12 @@ func (r *Route) startResolver(ctx context.Context) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
r.update(ctx)
|
||||
if err := r.update(ctx); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
if interval > failureInterval {
|
||||
ticker.Reset(failureInterval)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -168,17 +174,28 @@ func (r *Route) startResolver(ctx context.Context) {
|
||||
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.update(ctx)
|
||||
if err := r.update(ctx); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
// Use a lower ticker interval if the update fails
|
||||
if interval > failureInterval {
|
||||
ticker.Reset(failureInterval)
|
||||
}
|
||||
} else if interval > failureInterval {
|
||||
// Reset to the original interval if the update succeeds
|
||||
ticker.Reset(interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) update(ctx context.Context) {
|
||||
func (r *Route) update(ctx context.Context) error {
|
||||
if resolved, err := r.resolveDomains(); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
return fmt.Errorf("resolve domains: %w", err)
|
||||
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
|
||||
log.Errorf("Failed to update dynamic routes for [%v]: %v", r, err)
|
||||
return fmt.Errorf("update dynamic routes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) resolveDomains() (domainMap, error) {
|
||||
|
||||
@@ -8,9 +8,13 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewFactory(ctx context.Context, wgPort int) *Factory {
|
||||
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
|
||||
f := &Factory{wgPort: wgPort}
|
||||
|
||||
if userspace {
|
||||
return f
|
||||
}
|
||||
|
||||
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
|
||||
err := ebpfProxy.listen()
|
||||
if err != nil {
|
||||
|
||||
@@ -4,6 +4,6 @@ package wgproxy
|
||||
|
||||
import "context"
|
||||
|
||||
func NewFactory(ctx context.Context, wgPort int) *Factory {
|
||||
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
|
||||
return &Factory{wgPort: wgPort}
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -119,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -89,5 +89,9 @@ func _getInfo() string {
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var si sysinfo.SysInfo
|
||||
si.GetSysInfo()
|
||||
return si.Chassis.Serial, si.Product.Name, si.Product.Vendor
|
||||
serial := si.Chassis.Serial
|
||||
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
|
||||
serial = si.Product.Serial
|
||||
}
|
||||
return serial, si.Product.Name, si.Product.Vendor
|
||||
}
|
||||
|
||||
@@ -80,6 +80,7 @@ func main() {
|
||||
log.Errorf("check PID file: %v", err)
|
||||
return
|
||||
}
|
||||
client.setDefaultFonts()
|
||||
systray.Run(client.onTrayReady, client.onTrayExit)
|
||||
}
|
||||
}
|
||||
@@ -876,3 +877,88 @@ func checkPIDFile() error {
|
||||
|
||||
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
|
||||
}
|
||||
|
||||
func (s *serviceClient) setDefaultFonts() {
|
||||
var (
|
||||
defaultFontPath string
|
||||
)
|
||||
|
||||
//TODO: Linux Multiple Language Support
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
|
||||
case "windows":
|
||||
fontPath := s.getWindowsFontFilePath()
|
||||
defaultFontPath = fontPath
|
||||
}
|
||||
|
||||
_, err := os.Stat(defaultFontPath)
|
||||
|
||||
if err == nil {
|
||||
os.Setenv("FYNE_FONT", defaultFontPath)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) getWindowsFontFilePath() (fontPath string) {
|
||||
/*
|
||||
https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts
|
||||
https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list
|
||||
*/
|
||||
|
||||
var (
|
||||
fontFolder string = "C:/Windows/Fonts"
|
||||
fontMapping = map[string]string{
|
||||
"default": "Segoeui.ttf",
|
||||
"zh-CN": "Msyh.ttc",
|
||||
"am-ET": "Ebrima.ttf",
|
||||
"nirmala": "Nirmala.ttf",
|
||||
"chr-CHER-US": "Gadugi.ttf",
|
||||
"zh-HK": "Msjh.ttc",
|
||||
"zh-TW": "Msjh.ttc",
|
||||
"ja-JP": "Yugothm.ttc",
|
||||
"km-KH": "Leelawui.ttf",
|
||||
"ko-KR": "Malgun.ttf",
|
||||
"th-TH": "Leelawui.ttf",
|
||||
"ti-ET": "Ebrima.ttf",
|
||||
}
|
||||
nirMalaLang = []string{
|
||||
"as-IN",
|
||||
"bn-BD",
|
||||
"bn-IN",
|
||||
"gu-IN",
|
||||
"hi-IN",
|
||||
"kn-IN",
|
||||
"kok-IN",
|
||||
"ml-IN",
|
||||
"mr-IN",
|
||||
"ne-NP",
|
||||
"or-IN",
|
||||
"pa-IN",
|
||||
"si-LK",
|
||||
"ta-IN",
|
||||
"te-IN",
|
||||
}
|
||||
)
|
||||
cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get Windows default language setting: %v", err)
|
||||
fontPath = path.Join(fontFolder, fontMapping["default"])
|
||||
return
|
||||
}
|
||||
defaultLanguage := strings.TrimSpace(string(output))
|
||||
|
||||
for _, lang := range nirMalaLang {
|
||||
if defaultLanguage == lang {
|
||||
fontPath = path.Join(fontFolder, fontMapping["nirmala"])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if font, ok := fontMapping[defaultLanguage]; ok {
|
||||
fontPath = path.Join(fontFolder, font)
|
||||
} else {
|
||||
fontPath = path.Join(fontFolder, fontMapping["default"])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,28 +18,57 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
allRoutesText = "All routes"
|
||||
overlappingRoutesText = "Overlapping routes"
|
||||
exitNodeRoutesText = "Exit-node routes"
|
||||
allRoutes filter = "all"
|
||||
overlappingRoutes filter = "overlapping"
|
||||
exitNodeRoutes filter = "exit-node"
|
||||
getClientFMT = "get client: %v"
|
||||
)
|
||||
|
||||
type filter string
|
||||
|
||||
func (s *serviceClient) showRoutesUI() {
|
||||
s.wRoutes = s.app.NewWindow("NetBird Routes")
|
||||
|
||||
grid := container.New(layout.NewGridLayout(3))
|
||||
go s.updateRoutes(grid)
|
||||
allGrid := container.New(layout.NewGridLayout(3))
|
||||
go s.updateRoutes(allGrid, allRoutes)
|
||||
overlappingGrid := container.New(layout.NewGridLayout(3))
|
||||
exitNodeGrid := container.New(layout.NewGridLayout(3))
|
||||
routeCheckContainer := container.NewVBox()
|
||||
routeCheckContainer.Add(grid)
|
||||
tabs := container.NewAppTabs(
|
||||
container.NewTabItem(allRoutesText, allGrid),
|
||||
container.NewTabItem(overlappingRoutesText, overlappingGrid),
|
||||
container.NewTabItem(exitNodeRoutesText, exitNodeGrid),
|
||||
)
|
||||
tabs.OnSelected = func(item *container.TabItem) {
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}
|
||||
tabs.OnUnselected = func(item *container.TabItem) {
|
||||
grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
grid.Objects = nil
|
||||
}
|
||||
|
||||
routeCheckContainer.Add(tabs)
|
||||
scrollContainer := container.NewVScroll(routeCheckContainer)
|
||||
scrollContainer.SetMinSize(fyne.NewSize(200, 300))
|
||||
|
||||
buttonBox := container.NewHBox(
|
||||
layout.NewSpacer(),
|
||||
widget.NewButton("Refresh", func() {
|
||||
s.updateRoutes(grid)
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}),
|
||||
widget.NewButton("Select all", func() {
|
||||
s.selectAllRoutes()
|
||||
s.updateRoutes(grid)
|
||||
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
s.selectAllFilteredRoutes(f)
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}),
|
||||
widget.NewButton("Deselect All", func() {
|
||||
s.deselectAllRoutes()
|
||||
s.updateRoutes(grid)
|
||||
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
s.deselectAllFilteredRoutes(f)
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}),
|
||||
layout.NewSpacer(),
|
||||
)
|
||||
@@ -48,18 +78,12 @@ func (s *serviceClient) showRoutesUI() {
|
||||
s.wRoutes.SetContent(content)
|
||||
s.wRoutes.Show()
|
||||
|
||||
s.startAutoRefresh(5*time.Second, grid)
|
||||
s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}
|
||||
|
||||
func (s *serviceClient) updateRoutes(grid *fyne.Container) {
|
||||
routes, err := s.fetchRoutes()
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
s.showError(fmt.Errorf("get client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) {
|
||||
grid.Objects = nil
|
||||
grid.Refresh()
|
||||
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
@@ -67,7 +91,15 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
|
||||
grid.Add(idHeader)
|
||||
grid.Add(networkHeader)
|
||||
grid.Add(resolvedIPsHeader)
|
||||
for _, route := range routes {
|
||||
|
||||
filteredRoutes, err := s.getFilteredRoutes(f)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sortRoutesByIDs(filteredRoutes)
|
||||
|
||||
for _, route := range filteredRoutes {
|
||||
r := route
|
||||
|
||||
checkBox := widget.NewCheck(r.GetID(), func(checked bool) {
|
||||
@@ -80,35 +112,104 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
|
||||
grid.Add(checkBox)
|
||||
network := r.GetNetwork()
|
||||
domains := r.GetDomains()
|
||||
if len(domains) > 0 {
|
||||
network = strings.Join(domains, ", ")
|
||||
}
|
||||
grid.Add(widget.NewLabel(network))
|
||||
|
||||
if len(domains) > 0 {
|
||||
var resolvedIPsList []string
|
||||
for _, domain := range r.GetDomains() {
|
||||
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
|
||||
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
|
||||
}
|
||||
}
|
||||
// TODO: limit width
|
||||
resolvedIPsLabel := widget.NewLabel(strings.Join(resolvedIPsList, ", "))
|
||||
grid.Add(resolvedIPsLabel)
|
||||
} else {
|
||||
if len(domains) == 0 {
|
||||
grid.Add(widget.NewLabel(network))
|
||||
grid.Add(widget.NewLabel(""))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// our selectors are only for display
|
||||
noopFunc := func(_ string) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
domainsSelector := widget.NewSelect(domains, noopFunc)
|
||||
domainsSelector.Selected = domains[0]
|
||||
grid.Add(domainsSelector)
|
||||
|
||||
var resolvedIPsList []string
|
||||
for _, domain := range domains {
|
||||
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
|
||||
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
|
||||
}
|
||||
}
|
||||
|
||||
if len(resolvedIPsList) == 0 {
|
||||
grid.Add(widget.NewLabel(""))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: limit width within the selector display
|
||||
resolvedIPsSelector := widget.NewSelect(resolvedIPsList, noopFunc)
|
||||
resolvedIPsSelector.Selected = resolvedIPsList[0]
|
||||
resolvedIPsSelector.Resize(fyne.NewSize(100, 100))
|
||||
grid.Add(resolvedIPsSelector)
|
||||
}
|
||||
|
||||
s.wRoutes.Content().Refresh()
|
||||
grid.Refresh()
|
||||
}
|
||||
|
||||
func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) {
|
||||
routes, err := s.fetchRoutes()
|
||||
if err != nil {
|
||||
log.Errorf(getClientFMT, err)
|
||||
s.showError(fmt.Errorf(getClientFMT, err))
|
||||
return nil, err
|
||||
}
|
||||
switch f {
|
||||
case overlappingRoutes:
|
||||
return getOverlappingRoutes(routes), nil
|
||||
case exitNodeRoutes:
|
||||
return getExitNodeRoutes(routes), nil
|
||||
default:
|
||||
}
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func getOverlappingRoutes(routes []*proto.Route) []*proto.Route {
|
||||
var filteredRoutes []*proto.Route
|
||||
existingRange := make(map[string][]*proto.Route)
|
||||
for _, route := range routes {
|
||||
if len(route.Domains) > 0 {
|
||||
continue
|
||||
}
|
||||
if r, exists := existingRange[route.GetNetwork()]; exists {
|
||||
r = append(r, route)
|
||||
existingRange[route.GetNetwork()] = r
|
||||
} else {
|
||||
existingRange[route.GetNetwork()] = []*proto.Route{route}
|
||||
}
|
||||
}
|
||||
for _, r := range existingRange {
|
||||
if len(r) > 1 {
|
||||
filteredRoutes = append(filteredRoutes, r...)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
func getExitNodeRoutes(routes []*proto.Route) []*proto.Route {
|
||||
var filteredRoutes []*proto.Route
|
||||
for _, route := range routes {
|
||||
if route.Network == "0.0.0.0/0" {
|
||||
filteredRoutes = append(filteredRoutes, route)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
func sortRoutesByIDs(routes []*proto.Route) {
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get client: %v", err)
|
||||
return nil, fmt.Errorf(getClientFMT, err)
|
||||
}
|
||||
|
||||
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
|
||||
@@ -122,8 +223,8 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
|
||||
func (s *serviceClient) selectRoute(id string, checked bool) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
s.showError(fmt.Errorf("get client: %v", err))
|
||||
log.Errorf(getClientFMT, err)
|
||||
s.showError(fmt.Errorf(getClientFMT, err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -149,16 +250,14 @@ func (s *serviceClient) selectRoute(id string, checked bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) selectAllRoutes() {
|
||||
func (s *serviceClient) selectAllFilteredRoutes(f filter) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
log.Errorf(getClientFMT, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.SelectRoutesRequest{
|
||||
All: true,
|
||||
}
|
||||
req := s.getRoutesRequest(f, true)
|
||||
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
|
||||
log.Errorf("failed to select all routes: %v", err)
|
||||
s.showError(fmt.Errorf("failed to select all routes: %v", err))
|
||||
@@ -168,16 +267,14 @@ func (s *serviceClient) selectAllRoutes() {
|
||||
log.Debug("All routes selected")
|
||||
}
|
||||
|
||||
func (s *serviceClient) deselectAllRoutes() {
|
||||
func (s *serviceClient) deselectAllFilteredRoutes(f filter) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
log.Errorf(getClientFMT, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.SelectRoutesRequest{
|
||||
All: true,
|
||||
}
|
||||
req := s.getRoutesRequest(f, false)
|
||||
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
|
||||
log.Errorf("failed to deselect all routes: %v", err)
|
||||
s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
|
||||
@@ -187,17 +284,34 @@ func (s *serviceClient) deselectAllRoutes() {
|
||||
log.Debug("All routes deselected")
|
||||
}
|
||||
|
||||
func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest {
|
||||
req := &proto.SelectRoutesRequest{}
|
||||
if f == allRoutes {
|
||||
req.All = true
|
||||
} else {
|
||||
routes, err := s.getFilteredRoutes(f)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for _, route := range routes {
|
||||
req.RouteIDs = append(req.RouteIDs, route.GetID())
|
||||
}
|
||||
req.Append = appendRoute
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func (s *serviceClient) showError(err error) {
|
||||
wrappedMessage := wrapText(err.Error(), 50)
|
||||
|
||||
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
|
||||
}
|
||||
|
||||
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) {
|
||||
func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
|
||||
ticker := time.NewTicker(interval)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
s.updateRoutes(grid)
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -206,6 +320,23 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Cont
|
||||
})
|
||||
}
|
||||
|
||||
func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
|
||||
grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
|
||||
s.wRoutes.Content().Refresh()
|
||||
s.updateRoutes(grid, f)
|
||||
}
|
||||
|
||||
func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) {
|
||||
switch tabs.Selected().Text {
|
||||
case overlappingRoutesText:
|
||||
return overlappingGrid, overlappingRoutes
|
||||
case exitNodeRoutesText:
|
||||
return exitNodesGrid, exitNodeRoutes
|
||||
default:
|
||||
return allGrid, allRoutes
|
||||
}
|
||||
}
|
||||
|
||||
// wrapText inserts newlines into the text to ensure that each line is
|
||||
// no longer than 'lineLength' runes.
|
||||
func wrapText(text string, lineLength int) string {
|
||||
|
||||
@@ -7,6 +7,18 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
type ExecutionContext string
|
||||
|
||||
const (
|
||||
ExecutionContextKey = "executionContext"
|
||||
|
||||
HTTPSource ExecutionContext = "HTTP"
|
||||
GRPCSource ExecutionContext = "GRPC"
|
||||
SystemSource ExecutionContext = "SYSTEM"
|
||||
)
|
||||
|
||||
// ContextHook is a custom hook for add the source information for the entry
|
||||
@@ -30,6 +42,27 @@ func (hook ContextHook) Levels() []logrus.Level {
|
||||
func (hook ContextHook) Fire(entry *logrus.Entry) error {
|
||||
src := hook.parseSrc(entry.Caller.File)
|
||||
entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
|
||||
|
||||
if entry.Context == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
source, ok := entry.Context.Value(ExecutionContextKey).(ExecutionContext)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.Data["context"] = source
|
||||
|
||||
switch source {
|
||||
case HTTPSource:
|
||||
addHTTPFields(entry)
|
||||
case GRPCSource:
|
||||
addGRPCFields(entry)
|
||||
case SystemSource:
|
||||
addSystemFields(entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -59,3 +92,42 @@ func (hook ContextHook) parseSrc(filePath string) string {
|
||||
file := path.Base(filePath)
|
||||
return fmt.Sprintf("%s/%s", pkg, file)
|
||||
}
|
||||
|
||||
func addHTTPFields(entry *logrus.Entry) {
|
||||
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
|
||||
entry.Data[context.RequestIDKey] = ctxReqID
|
||||
}
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
}
|
||||
|
||||
func addGRPCFields(entry *logrus.Entry) {
|
||||
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
|
||||
entry.Data[context.RequestIDKey] = ctxReqID
|
||||
}
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
|
||||
entry.Data[context.PeerIDKey] = ctxDeviceID
|
||||
}
|
||||
}
|
||||
|
||||
func addSystemFields(entry *logrus.Entry) {
|
||||
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
|
||||
entry.Data[context.RequestIDKey] = ctxReqID
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
|
||||
entry.Data[context.PeerIDKey] = ctxDeviceID
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package formatter
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SetTextFormatter set the text formatter for given logger.
|
||||
func SetTextFormatter(logger *logrus.Logger) {
|
||||
@@ -9,6 +11,13 @@ func SetTextFormatter(logger *logrus.Logger) {
|
||||
logger.AddHook(NewContextHook())
|
||||
}
|
||||
|
||||
// SetJSONFormatter set the JSON formatter for given logger.
|
||||
func SetJSONFormatter(logger *logrus.Logger) {
|
||||
logger.Formatter = &logrus.JSONFormatter{}
|
||||
logger.ReportCaller = true
|
||||
logger.AddHook(NewContextHook())
|
||||
}
|
||||
|
||||
// SetLogcatFormatter set the logcat formatter for given logger.
|
||||
func SetLogcatFormatter(logger *logrus.Logger) {
|
||||
logger.Formatter = NewLogcatFormatter()
|
||||
|
||||
7
go.mod
7
go.mod
@@ -44,7 +44,6 @@ require (
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/martian/v3 v3.0.0
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
github.com/gopacket/gopacket v1.1.1
|
||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||
@@ -58,7 +57,7 @@ require (
|
||||
github.com/miekg/dns v1.1.43
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
@@ -189,8 +188,8 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.26.0 // indirect
|
||||
golang.org/x/image v0.10.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
|
||||
|
||||
16
go.sum
16
go.sum
@@ -209,8 +209,6 @@ github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSN
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs=
|
||||
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
|
||||
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
|
||||
@@ -335,8 +333,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd h1:IzGGIJMpz07aPs3R6/4sxZv63JoCMddftLpVodUK+Ec=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
|
||||
@@ -542,8 +540,8 @@ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJ
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/image v0.10.0 h1:gXjUUtwtx5yOE0VKWq1CH4IJAClq4UGgUA3i+rpON9M=
|
||||
golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0=
|
||||
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
|
||||
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
@@ -565,7 +563,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
@@ -655,11 +652,10 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
@@ -28,7 +28,11 @@ services:
|
||||
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
|
||||
volumes:
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
|
||||
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
# Signal
|
||||
signal:
|
||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||
@@ -40,6 +44,11 @@ services:
|
||||
# # port and command for Let's Encrypt validation
|
||||
# - 443:443
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Management
|
||||
management:
|
||||
@@ -63,12 +72,16 @@ services:
|
||||
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
||||
]
|
||||
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
# Coturn
|
||||
coturn:
|
||||
image: coturn/coturn:$COTURN_TAG
|
||||
restart: unless-stopped
|
||||
domainname: $TURN_DOMAIN
|
||||
#domainname: $TURN_DOMAIN # only needed when TLS is enabled
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
|
||||
@@ -76,7 +89,11 @@ services:
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
volumes:
|
||||
$MGMT_VOLUMENAME:
|
||||
$SIGNAL_VOLUMENAME:
|
||||
|
||||
@@ -50,7 +50,7 @@ check_jq() {
|
||||
wait_crdb() {
|
||||
set +e
|
||||
while true; do
|
||||
if $DOCKER_COMPOSE_COMMAND exec -T crdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then
|
||||
if $DOCKER_COMPOSE_COMMAND exec -T zdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then
|
||||
break
|
||||
fi
|
||||
echo -n " ."
|
||||
@@ -61,14 +61,16 @@ wait_crdb() {
|
||||
}
|
||||
|
||||
init_crdb() {
|
||||
echo -e "\nInitializing Zitadel's CockroachDB\n\n"
|
||||
$DOCKER_COMPOSE_COMMAND up -d crdb
|
||||
echo ""
|
||||
# shellcheck disable=SC2028
|
||||
echo -n "Waiting cockroachDB to become ready "
|
||||
wait_crdb
|
||||
$DOCKER_COMPOSE_COMMAND exec -T crdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/"
|
||||
handle_request_command_status $? "init_crdb failed" ""
|
||||
if [[ $ZITADEL_DATABASE == "cockroach" ]]; then
|
||||
echo -e "\nInitializing Zitadel's CockroachDB\n\n"
|
||||
$DOCKER_COMPOSE_COMMAND up -d zdb
|
||||
echo ""
|
||||
# shellcheck disable=SC2028
|
||||
echo -n "Waiting CockroachDB to become ready"
|
||||
wait_crdb
|
||||
$DOCKER_COMPOSE_COMMAND exec -T zdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/"
|
||||
handle_request_command_status $? "init_crdb failed" ""
|
||||
fi
|
||||
}
|
||||
|
||||
get_main_ip_address() {
|
||||
@@ -156,7 +158,7 @@ create_new_application() {
|
||||
"'"$BASE_REDIRECT_URL2"'"
|
||||
],
|
||||
"postLogoutRedirectUris": [
|
||||
"'"$LOGOUT_URL"'"
|
||||
"'"$LOGOUT_URL"'"
|
||||
],
|
||||
"RESPONSETypes": [
|
||||
"OIDC_RESPONSE_TYPE_CODE"
|
||||
@@ -461,6 +463,20 @@ initEnvironment() {
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $ZITADEL_DATABASE == "cockroach" ]]; then
|
||||
echo "Use CockroachDB as Zitadel database."
|
||||
ZDB=$(renderDockerComposeCockroachDB)
|
||||
ZITADEL_DB_ENV=$(renderZitadelCockroachDBEnv)
|
||||
else
|
||||
echo "Use Postgres as default Zitadel database."
|
||||
echo "For using CockroachDB please the environment variable 'export ZITADEL_DATABASE=cockroach'."
|
||||
POSTGRES_ROOT_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@"
|
||||
POSTGRES_ZITADEL_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@"
|
||||
ZDB=$(renderDockerComposePostgres)
|
||||
ZITADEL_DB_ENV=$(renderZitadelPostgresEnv)
|
||||
renderPostgresEnv > zdb.env
|
||||
fi
|
||||
|
||||
echo Rendering initial files...
|
||||
renderDockerCompose > docker-compose.yml
|
||||
renderCaddyfile > Caddyfile
|
||||
@@ -474,7 +490,7 @@ initEnvironment() {
|
||||
|
||||
init_crdb
|
||||
|
||||
echo -e "\nStarting Zidatel IDP for user management\n\n"
|
||||
echo -e "\nStarting Zitadel IDP for user management\n\n"
|
||||
$DOCKER_COMPOSE_COMMAND up -d caddy zitadel
|
||||
init_zitadel
|
||||
|
||||
@@ -634,15 +650,15 @@ renderManagementJson() {
|
||||
"ExtraConfig": {
|
||||
"ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1"
|
||||
}
|
||||
},
|
||||
"DeviceAuthorizationFlow": {
|
||||
"Provider": "hosted",
|
||||
"ProviderConfig": {
|
||||
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"Scope": "openid"
|
||||
}
|
||||
},
|
||||
},
|
||||
"DeviceAuthorizationFlow": {
|
||||
"Provider": "hosted",
|
||||
"ProviderConfig": {
|
||||
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"Scope": "openid"
|
||||
}
|
||||
},
|
||||
"PKCEAuthorizationFlow": {
|
||||
"ProviderConfig": {
|
||||
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
@@ -679,16 +695,6 @@ renderZitadelEnv() {
|
||||
cat <<EOF
|
||||
ZITADEL_LOG_LEVEL=debug
|
||||
ZITADEL_MASTERKEY=$ZITADEL_MASTERKEY
|
||||
ZITADEL_DATABASE_COCKROACH_HOST=crdb
|
||||
ZITADEL_DATABASE_COCKROACH_USER_USERNAME=zitadel_user
|
||||
ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE=verify-full
|
||||
ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT="/crdb-certs/ca.crt"
|
||||
ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT="/crdb-certs/client.zitadel_user.crt"
|
||||
ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY="/crdb-certs/client.zitadel_user.key"
|
||||
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_MODE=verify-full
|
||||
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_ROOTCERT="/crdb-certs/ca.crt"
|
||||
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_CERT="/crdb-certs/client.root.crt"
|
||||
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_KEY="/crdb-certs/client.root.key"
|
||||
ZITADEL_EXTERNALSECURE=$ZITADEL_EXTERNALSECURE
|
||||
ZITADEL_TLS_ENABLED="false"
|
||||
ZITADEL_EXTERNALPORT=$NETBIRD_PORT
|
||||
@@ -698,6 +704,43 @@ ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_USERNAME=zitadel-admin-sa
|
||||
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_NAME=Admin
|
||||
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_SCOPES=openid
|
||||
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_EXPIRATIONDATE=$ZIDATE_TOKEN_EXPIRATION_DATE
|
||||
$ZITADEL_DB_ENV
|
||||
EOF
|
||||
}
|
||||
|
||||
renderZitadelCockroachDBEnv() {
|
||||
cat <<EOF
|
||||
ZITADEL_DATABASE_COCKROACH_HOST=zdb
|
||||
ZITADEL_DATABASE_COCKROACH_USER_USERNAME=zitadel_user
|
||||
ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE=verify-full
|
||||
ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT="/zdb-certs/ca.crt"
|
||||
ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT="/zdb-certs/client.zitadel_user.crt"
|
||||
ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY="/zdb-certs/client.zitadel_user.key"
|
||||
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_MODE=verify-full
|
||||
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_ROOTCERT="/zdb-certs/ca.crt"
|
||||
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_CERT="/zdb-certs/client.root.crt"
|
||||
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_KEY="/zdb-certs/client.root.key"
|
||||
EOF
|
||||
}
|
||||
|
||||
renderZitadelPostgresEnv() {
|
||||
cat <<EOF
|
||||
ZITADEL_DATABASE_POSTGRES_HOST=zdb
|
||||
ZITADEL_DATABASE_POSTGRES_PORT=5432
|
||||
ZITADEL_DATABASE_POSTGRES_DATABASE=zitadel
|
||||
ZITADEL_DATABASE_POSTGRES_USER_USERNAME=zitadel
|
||||
ZITADEL_DATABASE_POSTGRES_USER_PASSWORD=$POSTGRES_ZITADEL_PASSWORD
|
||||
ZITADEL_DATABASE_POSTGRES_USER_SSL_MODE=disable
|
||||
ZITADEL_DATABASE_POSTGRES_ADMIN_USERNAME=root
|
||||
ZITADEL_DATABASE_POSTGRES_ADMIN_PASSWORD=$POSTGRES_ROOT_PASSWORD
|
||||
ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_MODE=disable
|
||||
EOF
|
||||
}
|
||||
|
||||
renderPostgresEnv() {
|
||||
cat <<EOF
|
||||
POSTGRES_USER=root
|
||||
POSTGRES_PASSWORD=$POSTGRES_ROOT_PASSWORD
|
||||
EOF
|
||||
}
|
||||
|
||||
@@ -717,18 +760,28 @@ services:
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
#UI dashboard
|
||||
# UI dashboard
|
||||
dashboard:
|
||||
image: netbirdio/dashboard:latest
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
env_file:
|
||||
- ./dashboard.env
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
# Signal
|
||||
signal:
|
||||
image: netbirdio/signal:latest
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
# Management
|
||||
management:
|
||||
image: netbirdio/management:latest
|
||||
@@ -746,52 +799,49 @@ services:
|
||||
"--dns-domain=netbird.selfhosted",
|
||||
"--idp-sign-key-refresh-enabled",
|
||||
]
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
# Coturn, AKA relay server
|
||||
coturn:
|
||||
image: coturn/coturn
|
||||
restart: unless-stopped
|
||||
domainname: netbird.relay.selfhosted
|
||||
#domainname: netbird.relay.selfhosted
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
# Zitadel - identity provider
|
||||
zitadel:
|
||||
restart: 'always'
|
||||
networks: [netbird]
|
||||
image: 'ghcr.io/zitadel/zitadel:v2.31.3'
|
||||
image: 'ghcr.io/zitadel/zitadel:v2.54.3'
|
||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||
env_file:
|
||||
- ./zitadel.env
|
||||
depends_on:
|
||||
crdb:
|
||||
zdb:
|
||||
condition: 'service_healthy'
|
||||
volumes:
|
||||
- ./machinekey:/machinekey
|
||||
- netbird_zitadel_certs:/crdb-certs:ro
|
||||
# CockroachDB for zitadel
|
||||
crdb:
|
||||
restart: 'always'
|
||||
networks: [netbird]
|
||||
image: 'cockroachdb/cockroach:v22.2.2'
|
||||
command: 'start-single-node --advertise-addr crdb'
|
||||
volumes:
|
||||
- netbird_crdb_data:/cockroach/cockroach-data
|
||||
- netbird_crdb_certs:/cockroach/certs
|
||||
- netbird_zitadel_certs:/zitadel-certs
|
||||
healthcheck:
|
||||
test: [ "CMD", "curl", "-f", "http://localhost:8080/health?ready=1" ]
|
||||
interval: '10s'
|
||||
timeout: '30s'
|
||||
retries: 5
|
||||
start_period: '20s'
|
||||
|
||||
volumes:
|
||||
- netbird_zitadel_certs:/zdb-certs:ro
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
$ZDB
|
||||
netbird_zdb_data:
|
||||
netbird_management:
|
||||
netbird_caddy_data:
|
||||
netbird_crdb_data:
|
||||
netbird_crdb_certs:
|
||||
netbird_zitadel_certs:
|
||||
|
||||
networks:
|
||||
@@ -799,4 +849,59 @@ networks:
|
||||
EOF
|
||||
}
|
||||
|
||||
renderDockerComposeCockroachDB() {
|
||||
cat <<EOF
|
||||
# CockroachDB for Zitadel
|
||||
zdb:
|
||||
restart: 'always'
|
||||
networks: [netbird]
|
||||
image: 'cockroachdb/cockroach:latest-v23.2'
|
||||
command: 'start-single-node --advertise-addr zdb'
|
||||
volumes:
|
||||
- netbird_zdb_data:/cockroach/cockroach-data
|
||||
- netbird_zdb_certs:/cockroach/certs
|
||||
- netbird_zitadel_certs:/zitadel-certs
|
||||
healthcheck:
|
||||
test: [ "CMD", "curl", "-f", "http://localhost:8080/health?ready=1" ]
|
||||
interval: '10s'
|
||||
timeout: '30s'
|
||||
retries: 5
|
||||
start_period: '20s'
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
volumes:
|
||||
netbird_zdb_certs:
|
||||
EOF
|
||||
}
|
||||
|
||||
renderDockerComposePostgres() {
|
||||
cat <<EOF
|
||||
# Postgres for Zitadel
|
||||
zdb:
|
||||
restart: 'always'
|
||||
networks: [netbird]
|
||||
image: 'postgres:16-alpine'
|
||||
env_file:
|
||||
- ./zdb.env
|
||||
volumes:
|
||||
- netbird_zdb_data:/var/lib/postgresql/data:rw
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready", "-d", "db_prod"]
|
||||
interval: 5s
|
||||
timeout: 60s
|
||||
retries: 10
|
||||
start_period: 5s
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
volumes:
|
||||
EOF
|
||||
}
|
||||
|
||||
initEnvironment
|
||||
|
||||
@@ -62,7 +62,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -70,13 +70,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -35,8 +36,10 @@ import (
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
httpapi "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -77,6 +80,10 @@ var (
|
||||
Short: "start NetBird Management Server",
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
flag.Parse()
|
||||
|
||||
//nolint
|
||||
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
|
||||
|
||||
err := util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
@@ -85,7 +92,7 @@ var (
|
||||
// detect whether user specified a port
|
||||
userPort := cmd.Flag("port").Changed
|
||||
|
||||
config, err = loadMgmtConfig(mgmtConfig)
|
||||
config, err = loadMgmtConfig(ctx, mgmtConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
|
||||
}
|
||||
@@ -116,6 +123,11 @@ var (
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.SystemSource)
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to migrate files %v", err)
|
||||
@@ -131,11 +143,11 @@ var (
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = appMetrics.Expose(mgmtMetricsPort, "/metrics")
|
||||
err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
store, err := server.NewStore(config.StoreConfig.Engine, config.Datadir, appMetrics)
|
||||
store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
@@ -143,7 +155,7 @@ var (
|
||||
|
||||
var idpManager idp.Manager
|
||||
if config.IdpManagerConfig != nil {
|
||||
idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics)
|
||||
idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
|
||||
}
|
||||
@@ -152,32 +164,32 @@ var (
|
||||
if disableSingleAccMode {
|
||||
mgmtSingleAccModeDomain = ""
|
||||
}
|
||||
eventStore, key, err := integrations.InitEventStore(config.Datadir, config.DataStoreEncryptionKey)
|
||||
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %s", err)
|
||||
}
|
||||
|
||||
if config.DataStoreEncryptionKey != key {
|
||||
log.Infof("update config with activity store key")
|
||||
log.WithContext(ctx).Infof("update config with activity store key")
|
||||
config.DataStoreEncryptionKey = key
|
||||
err := updateMgmtConfig(mgmtConfig, config)
|
||||
err := updateMgmtConfig(ctx, mgmtConfig, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write out store encryption key: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
geo, err := geolocation.NewGeolocation(config.Datadir)
|
||||
geo, err := geolocation.NewGeolocation(ctx, config.Datadir)
|
||||
if err != nil {
|
||||
log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
|
||||
log.WithContext(ctx).Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
|
||||
} else {
|
||||
log.Infof("geo location service has been initialized from %s", config.Datadir)
|
||||
log.WithContext(ctx).Infof("geo location service has been initialized from %s", config.Datadir)
|
||||
}
|
||||
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore)
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
|
||||
}
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
@@ -188,13 +200,13 @@ var (
|
||||
trustedPeers := config.ReverseProxy.TrustedPeers
|
||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
|
||||
log.Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
||||
log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
||||
trustedPeers = defaultTrustedPeers
|
||||
}
|
||||
trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies
|
||||
trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount
|
||||
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
|
||||
log.Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
||||
log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
||||
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
|
||||
}
|
||||
realipOpts := []realip.Option{
|
||||
@@ -206,8 +218,8 @@ var (
|
||||
gRPCOpts := []grpc.ServerOption{
|
||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||
grpc.KeepaliveParams(kasp),
|
||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...)),
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...)),
|
||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
||||
}
|
||||
|
||||
var certManager *autocert.Manager
|
||||
@@ -224,7 +236,7 @@ var (
|
||||
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
|
||||
if err != nil {
|
||||
log.Errorf("cannot load TLS credentials: %v", err)
|
||||
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
|
||||
return err
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(tlsConfig)
|
||||
@@ -233,6 +245,7 @@ var (
|
||||
}
|
||||
|
||||
jwtValidator, err := jwtclaims.NewJWTValidator(
|
||||
ctx,
|
||||
config.HttpConfig.AuthIssuer,
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.AuthKeysLocation,
|
||||
@@ -249,26 +262,24 @@ var (
|
||||
KeysLocation: config.HttpConfig.AuthKeysLocation,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||
}
|
||||
|
||||
ephemeralManager := server.NewEphemeralManager(store, accountManager)
|
||||
ephemeralManager.LoadInitialPeers()
|
||||
ephemeralManager.LoadInitialPeers(ctx)
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
|
||||
srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating gRPC API handler: %v", err)
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||
|
||||
installationID, err := getInstallationID(store)
|
||||
installationID, err := getInstallationID(ctx, store)
|
||||
if err != nil {
|
||||
log.Errorf("cannot load TLS credentials: %v", err)
|
||||
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -278,18 +289,18 @@ var (
|
||||
idpManager = config.IdpManagerConfig.ManagerType
|
||||
}
|
||||
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
|
||||
go metricsWorker.Run()
|
||||
go metricsWorker.Run(ctx)
|
||||
}
|
||||
|
||||
var compatListener net.Listener
|
||||
if mgmtPort != ManagementLegacyPort {
|
||||
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
|
||||
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
|
||||
compatListener, err = serveGRPC(gRPCAPIHandler, ManagementLegacyPort)
|
||||
compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler)
|
||||
@@ -306,8 +317,8 @@ var (
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
|
||||
}
|
||||
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
|
||||
serveHTTP(cml, certManager.HTTPHandler(nil))
|
||||
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
|
||||
serveHTTP(ctx, cml, certManager.HTTPHandler(nil))
|
||||
}
|
||||
} else if tlsConfig != nil {
|
||||
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig)
|
||||
@@ -321,14 +332,14 @@ var (
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("management server version %s", version.NetbirdVersion())
|
||||
log.Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
|
||||
serveGRPCWithHTTP(listener, rootHandler, tlsEnabled)
|
||||
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
|
||||
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
|
||||
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
|
||||
|
||||
SetupCloseHandler()
|
||||
|
||||
<-stopCh
|
||||
integratedPeerValidator.Stop()
|
||||
integratedPeerValidator.Stop(ctx)
|
||||
if geo != nil {
|
||||
_ = geo.Stop()
|
||||
}
|
||||
@@ -339,39 +350,68 @@ var (
|
||||
_ = certManager.Listener().Close()
|
||||
}
|
||||
gRPCAPIHandler.Stop()
|
||||
_ = store.Close()
|
||||
_ = eventStore.Close()
|
||||
log.Infof("stopped Management Service")
|
||||
_ = store.Close(ctx)
|
||||
_ = eventStore.Close(ctx)
|
||||
log.WithContext(ctx).Infof("stopped Management Service")
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func notifyStop(msg string) {
|
||||
func unaryInterceptor(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
reqID := uuid.New().String()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.GRPCSource)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
func streamInterceptor(
|
||||
srv interface{},
|
||||
ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler,
|
||||
) error {
|
||||
reqID := uuid.New().String()
|
||||
wrapped := grpcMiddleware.WrapServerStream(ss)
|
||||
//nolint
|
||||
ctx := context.WithValue(ss.Context(), formatter.ExecutionContextKey, formatter.GRPCSource)
|
||||
//nolint
|
||||
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
return handler(srv, wrapped)
|
||||
}
|
||||
|
||||
func notifyStop(ctx context.Context, msg string) {
|
||||
select {
|
||||
case stopCh <- 1:
|
||||
log.Error(msg)
|
||||
log.WithContext(ctx).Error(msg)
|
||||
default:
|
||||
// stop has been already called, nothing to report
|
||||
}
|
||||
}
|
||||
|
||||
func getInstallationID(store server.Store) (string, error) {
|
||||
func getInstallationID(ctx context.Context, store server.Store) (string, error) {
|
||||
installationID := store.GetInstallationID()
|
||||
if installationID != "" {
|
||||
return installationID, nil
|
||||
}
|
||||
|
||||
installationID = strings.ToUpper(uuid.New().String())
|
||||
err := store.SaveInstallationID(installationID)
|
||||
err := store.SaveInstallationID(ctx, installationID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return installationID, nil
|
||||
}
|
||||
|
||||
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
|
||||
func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -379,22 +419,22 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
|
||||
go func() {
|
||||
err := grpcServer.Serve(listener)
|
||||
if err != nil {
|
||||
notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
|
||||
notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
|
||||
}
|
||||
}()
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func serveHTTP(httpListener net.Listener, handler http.Handler) {
|
||||
func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
|
||||
go func() {
|
||||
err := http.Serve(httpListener, handler)
|
||||
if err != nil {
|
||||
notifyStop(fmt.Sprintf("failed running HTTP server: %v", err))
|
||||
notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled bool) {
|
||||
func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
|
||||
go func() {
|
||||
var err error
|
||||
if tlsEnabled {
|
||||
@@ -411,7 +451,7 @@ func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled b
|
||||
if err != nil {
|
||||
select {
|
||||
case stopCh <- 1:
|
||||
log.Errorf("failed to serve HTTP and gRPC server: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err)
|
||||
default:
|
||||
// stop has been already called, nothing to report
|
||||
}
|
||||
@@ -431,7 +471,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
|
||||
})
|
||||
}
|
||||
|
||||
func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
|
||||
loadedConfig := &server.Config{}
|
||||
_, err := util.ReadJson(mgmtConfigPath, loadedConfig)
|
||||
if err != nil {
|
||||
@@ -452,26 +492,26 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint
|
||||
if oidcEndpoint != "" {
|
||||
// if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically
|
||||
log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
|
||||
oidcConfig, err := fetchOIDCConfig(oidcEndpoint)
|
||||
log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
|
||||
oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
||||
log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
||||
|
||||
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
||||
oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer)
|
||||
loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer
|
||||
|
||||
log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
|
||||
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
|
||||
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
|
||||
|
||||
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
|
||||
log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
|
||||
|
||||
@@ -479,7 +519,7 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
|
||||
u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
|
||||
|
||||
@@ -489,10 +529,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
}
|
||||
|
||||
if loadedConfig.PKCEAuthorizationFlow != nil {
|
||||
log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
|
||||
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
|
||||
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
|
||||
}
|
||||
@@ -501,8 +541,8 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
return loadedConfig, err
|
||||
}
|
||||
|
||||
func updateMgmtConfig(path string, config *server.Config) error {
|
||||
return util.DirectWriteJson(path, config)
|
||||
func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error {
|
||||
return util.DirectWriteJson(ctx, path, config)
|
||||
}
|
||||
|
||||
// OIDCConfigResponse used for parsing OIDC config response
|
||||
@@ -515,7 +555,7 @@ type OIDCConfigResponse struct {
|
||||
}
|
||||
|
||||
// fetchOIDCConfig fetches OIDC configuration from the IDP
|
||||
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||
func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||
res, err := http.Get(oidcEndpoint)
|
||||
if err != nil {
|
||||
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
|
||||
@@ -524,7 +564,7 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||
defer func() {
|
||||
err := res.Body.Close()
|
||||
if err != nil {
|
||||
log.Debugf("failed closing response body %v", err)
|
||||
log.WithContext(ctx).Debugf("failed closing response body %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command."
|
||||
@@ -26,10 +29,13 @@ var upCmd = &cobra.Command{
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
if err := server.MigrateFileStoreToSqlite(mgmtDataDir); err != nil {
|
||||
//nolint
|
||||
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
|
||||
|
||||
if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info("Migration finished successfully")
|
||||
log.WithContext(ctx).Info("Migration finished successfully")
|
||||
|
||||
return nil
|
||||
},
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -29,11 +30,11 @@ import (
|
||||
type MocIntegratedValidator struct {
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
||||
return update, nil
|
||||
}
|
||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
||||
@@ -44,15 +45,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
||||
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
|
||||
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,7 +61,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
|
||||
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) Stop() {
|
||||
func (MocIntegratedValidator) Stop(_ context.Context) {
|
||||
}
|
||||
|
||||
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
|
||||
@@ -85,7 +86,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac
|
||||
setupKey = key.Key
|
||||
}
|
||||
|
||||
_, _, err := manager.AddPeer(setupKey, userID, peer)
|
||||
_, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer)
|
||||
if err != nil {
|
||||
t.Error("expected to add new peer successfully after creating new account, but failed", err)
|
||||
}
|
||||
@@ -395,7 +396,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId("account-1", userID, "netbird.io")
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -409,7 +410,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
validatedPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers)
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -419,7 +420,7 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(accountID, userId, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@@ -430,7 +431,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -439,7 +440,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(userID)
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
|
||||
return
|
||||
@@ -630,11 +631,11 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
if testCase.inputUpdateAttrs {
|
||||
err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
require.NoError(t, err, "update init user failed")
|
||||
}
|
||||
|
||||
@@ -642,7 +643,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
testCase.inputClaims.AccountId = initAccount.Id
|
||||
}
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
||||
require.NoError(t, err, "support function failed")
|
||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||
@@ -661,12 +662,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
initAccount := newAccountWithId("", userId, domain)
|
||||
initAccount := newAccountWithId(context.Background(), "", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID := initAccount.Id
|
||||
acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain)
|
||||
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
// as initAccount was created without account id we have to take the id after account initialization
|
||||
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
|
||||
@@ -682,18 +683,18 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
err := manager.Store.SaveAccount(initAccount)
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
@@ -701,11 +702,11 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(initAccount)
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
|
||||
@@ -728,7 +729,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
@@ -742,7 +743,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -751,7 +752,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
Store: store,
|
||||
}
|
||||
|
||||
account, user, pat, err := am.GetAccountFromPAT(token)
|
||||
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||
}
|
||||
@@ -763,7 +764,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
@@ -778,7 +779,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -787,12 +788,12 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
Store: store,
|
||||
}
|
||||
|
||||
err = am.MarkPATUsed("tokenId")
|
||||
err = am.MarkPATUsed(context.Background(), "tokenId")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when marking PAT used: %s", err)
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount("account_id")
|
||||
account, err = am.Store.GetAccount(context.Background(), "account_id")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
@@ -807,7 +808,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
userId := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(userId, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -815,7 +816,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(userId)
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userId)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
|
||||
}
|
||||
@@ -834,7 +835,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
domain := "hotmail.com"
|
||||
account, err := manager.GetOrCreateAccountByUser(userId, domain)
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -848,7 +849,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
|
||||
domain = "gmail.com"
|
||||
|
||||
account, err = manager.GetOrCreateAccountByUser(userId, domain)
|
||||
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("got the following error while retrieving existing acc: %v", err)
|
||||
}
|
||||
@@ -871,7 +872,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -880,20 +881,20 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
|
||||
}
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID("", "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
|
||||
if err == nil {
|
||||
t.Errorf("expected an error when user and account IDs are empty")
|
||||
}
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
|
||||
account := newAccountWithId(accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(account)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -915,7 +916,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
// AddAccount has been already tested so we can assume it is correct and compare results
|
||||
getAccount, err := manager.Store.GetAccount(account.Id)
|
||||
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -952,12 +953,12 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.DeleteAccount(account.Id, userId)
|
||||
err = manager.DeleteAccount(context.Background(), account.Id, userId)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
getAccount, err := manager.Store.GetAccount(account.Id)
|
||||
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err == nil {
|
||||
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
|
||||
}
|
||||
@@ -978,7 +979,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
|
||||
serial := account.Network.CurrentSerial() // should be 0
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -997,7 +998,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedSetupKey := setupKey.Key
|
||||
|
||||
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
@@ -1006,7 +1007,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccount(account.Id)
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1045,7 +1046,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.cloud")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1065,7 +1066,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedUserID := userID
|
||||
|
||||
peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
@@ -1074,7 +1075,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccount(account.Id)
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1121,7 +1122,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -1140,7 +1141,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
@@ -1156,14 +1157,14 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
peer2 := getPeer()
|
||||
peer3 := getPeer()
|
||||
|
||||
account, err = manager.Store.GetAccount(account.Id)
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(peer1.ID)
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
group := group.Group{
|
||||
ID: "group-id",
|
||||
@@ -1197,7 +1198,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SaveGroup(account.Id, userID, &group); err != nil {
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1217,7 +1218,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePolicy(account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1237,7 +1238,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(account.Id, userID, &policy); err != nil {
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1256,7 +1257,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil {
|
||||
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
|
||||
t.Errorf("delete peer: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1277,9 +1278,9 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}()
|
||||
|
||||
// clean policy is pre requirement for delete group
|
||||
_ = manager.DeletePolicy(account.Id, policy.ID, userID)
|
||||
_ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
|
||||
|
||||
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil {
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1301,7 +1302,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -1315,7 +1316,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
|
||||
peerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
|
||||
})
|
||||
@@ -1324,12 +1325,12 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
err = manager.DeletePeer(account.Id, peerKey, userID)
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccount(account.Id)
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1357,7 +1358,7 @@ func getEvent(t *testing.T, accountID string, manager AccountManager, eventType
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("no PeerAddedWithSetupKey event was generated")
|
||||
default:
|
||||
events, err := manager.GetEvents(accountID, userID)
|
||||
events, err := manager.GetEvents(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1389,7 +1390,7 @@ func TestGetUsersFromAccount(t *testing.T) {
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
userInfos, err := manager.GetUsersFromAccount(accountId, "1")
|
||||
userInfos, err := manager.GetUsersFromAccount(context.Background(), accountId, "1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1500,7 +1501,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||
routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
@@ -1510,7 +1511,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||
emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
@@ -1645,7 +1646,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
assert.NotNil(t, account.Settings)
|
||||
@@ -1657,23 +1658,23 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@@ -1682,10 +1683,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
manager.peerLoginExpiry = &MockScheduler{
|
||||
CancelFunc: func(IDs []string) {
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {
|
||||
wg.Done()
|
||||
},
|
||||
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
wg.Done()
|
||||
},
|
||||
}
|
||||
@@ -1693,11 +1694,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
// disable expiration first
|
||||
update := peer.Copy()
|
||||
update.LoginExpirationEnabled = false
|
||||
_, err = manager.UpdatePeer(account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
// enabling expiration should trigger the routine
|
||||
update.LoginExpirationEnabled = true
|
||||
_, err = manager.UpdatePeer(account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1710,18 +1711,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@@ -1730,18 +1731,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
manager.peerLoginExpiry = &MockScheduler{
|
||||
CancelFunc: func(IDs []string) {
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {
|
||||
wg.Done()
|
||||
},
|
||||
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
wg.Done()
|
||||
},
|
||||
}
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1754,35 +1755,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
manager.peerLoginExpiry = &MockScheduler{
|
||||
CancelFunc: func(IDs []string) {
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {
|
||||
wg.Done()
|
||||
},
|
||||
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
wg.Done()
|
||||
},
|
||||
}
|
||||
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@@ -1795,7 +1796,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
wg.Add(1)
|
||||
|
||||
// disabling PeerLoginExpirationEnabled should trigger cancel
|
||||
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@@ -1810,10 +1811,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@@ -1821,19 +1822,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
|
||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
require.NoError(t, err, "unable to get account by ID")
|
||||
|
||||
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@@ -2294,7 +2295,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2305,7 +2306,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -86,7 +87,7 @@ type Store struct {
|
||||
}
|
||||
|
||||
// NewSQLiteStore creates a new Store with an event table if not exists.
|
||||
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
|
||||
func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
|
||||
dbFile := filepath.Join(dataDir, eventSinkDB)
|
||||
db, err := sql.Open("sqlite3", dbFile)
|
||||
if err != nil {
|
||||
@@ -111,7 +112,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updateDeletedUsersTable(db)
|
||||
err = updateDeletedUsersTable(ctx, db)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
@@ -153,7 +154,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
|
||||
events := make([]*activity.Event, 0)
|
||||
var cryptErr error
|
||||
for result.Next() {
|
||||
@@ -235,14 +236,14 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
}
|
||||
|
||||
if cryptErr != nil {
|
||||
log.Warnf("%s", cryptErr)
|
||||
log.WithContext(ctx).Warnf("%s", cryptErr)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
|
||||
func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
|
||||
func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
|
||||
stmt := store.selectDescStatement
|
||||
if !descending {
|
||||
stmt = store.selectAscStatement
|
||||
@@ -254,11 +255,11 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
|
||||
}
|
||||
|
||||
defer result.Close() //nolint
|
||||
return store.processResult(result)
|
||||
return store.processResult(ctx, result)
|
||||
}
|
||||
|
||||
// Save an event in the SQLite events table end encrypt the "email" element in meta map
|
||||
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
|
||||
func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
|
||||
var jsonMeta string
|
||||
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
|
||||
if err != nil {
|
||||
@@ -317,15 +318,15 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
|
||||
}
|
||||
|
||||
// Close the Store
|
||||
func (store *Store) Close() error {
|
||||
func (store *Store) Close(_ context.Context) error {
|
||||
if store.db != nil {
|
||||
return store.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateDeletedUsersTable(db *sql.DB) error {
|
||||
log.Debugf("check deleted_users table version")
|
||||
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
|
||||
log.WithContext(ctx).Debugf("check deleted_users table version")
|
||||
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -360,7 +361,7 @@ func updateDeletedUsersTable(db *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("update delted_users table")
|
||||
log.WithContext(ctx).Debugf("update delted_users table")
|
||||
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,17 +14,17 @@ import (
|
||||
func TestNewSQLiteStore(t *testing.T) {
|
||||
dataDir := t.TempDir()
|
||||
key, _ := GenerateKey()
|
||||
store, err := NewSQLiteStore(dataDir, key)
|
||||
store, err := NewSQLiteStore(context.Background(), dataDir, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer store.Close() //nolint
|
||||
defer store.Close(context.Background()) //nolint
|
||||
|
||||
accountID := "account_1"
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = store.Save(&activity.Event{
|
||||
_, err = store.Save(context.Background(), &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "user_" + fmt.Sprint(i),
|
||||
@@ -36,7 +37,7 @@ func TestNewSQLiteStore(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
result, err := store.Get(accountID, 0, 10, false)
|
||||
result, err := store.Get(context.Background(), accountID, 0, 10, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -45,7 +46,7 @@ func TestNewSQLiteStore(t *testing.T) {
|
||||
assert.Len(t, result, 10)
|
||||
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))
|
||||
|
||||
result, err = store.Get(accountID, 0, 5, true)
|
||||
result, err = store.Get(context.Background(), accountID, 0, 5, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
package activity
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Store provides an interface to store or stream events.
|
||||
type Store interface {
|
||||
// Save an event in the store
|
||||
Save(event *Event) (*Event, error)
|
||||
Save(ctx context.Context, event *Event) (*Event, error)
|
||||
// Get returns "limit" number of events from the "offset" index ordered descending or ascending by a timestamp
|
||||
Get(accountID string, offset, limit int, descending bool) ([]*Event, error)
|
||||
Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error)
|
||||
// Close the sink flushing events if necessary
|
||||
Close() error
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
// InMemoryEventStore implements the Store interface storing data in-memory
|
||||
@@ -20,7 +23,7 @@ type InMemoryEventStore struct {
|
||||
}
|
||||
|
||||
// Save sets the Event.ID to 1
|
||||
func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
|
||||
func (store *InMemoryEventStore) Save(_ context.Context, event *Event) (*Event, error) {
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
if store.events == nil {
|
||||
@@ -33,7 +36,7 @@ func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
|
||||
}
|
||||
|
||||
// Get returns a list of ALL events that belong to the given accountID without taking offset, limit and order into consideration
|
||||
func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descending bool) ([]*Event, error) {
|
||||
func (store *InMemoryEventStore) Get(_ context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) {
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
events := make([]*Event, 0)
|
||||
@@ -46,7 +49,7 @@ func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descen
|
||||
}
|
||||
|
||||
// Close cleans up the event list
|
||||
func (store *InMemoryEventStore) Close() error {
|
||||
func (store *InMemoryEventStore) Close(_ context.Context) error {
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
store.events = make([]*Event, 0)
|
||||
|
||||
8
management/server/context/keys.go
Normal file
8
management/server/context/keys.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package context
|
||||
|
||||
const (
|
||||
RequestIDKey = "requestID"
|
||||
AccountIDKey = "accountID"
|
||||
UserIDKey = "userID"
|
||||
PeerIDKey = "peerID"
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
@@ -34,11 +35,11 @@ func (d DNSSettings) Copy() DNSSettings {
|
||||
}
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -56,11 +57,11 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
|
||||
}
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -89,7 +90,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
|
||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -97,17 +98,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
|
||||
for _, id := range addedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
for _, id := range removedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
// todo: check if before/after groups are in use by dns, acl, routes and if it has peers
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -149,9 +151,9 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
|
||||
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
|
||||
if dnsDomain == "" {
|
||||
log.Errorf("no dns domain is set, returning empty zone")
|
||||
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
|
||||
return nbdns.CustomZone{}
|
||||
}
|
||||
|
||||
@@ -161,7 +163,7 @@ func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if peer.DNSLabel == "" {
|
||||
log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
|
||||
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -210,14 +212,14 @@ func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
|
||||
func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) {
|
||||
for _, peer := range account.Peers {
|
||||
label, err := getPeerHostLabel(peer.Name, peerLabels)
|
||||
if err != nil {
|
||||
log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
|
||||
log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
|
||||
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
|
||||
if err != nil {
|
||||
log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
|
||||
log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -35,7 +36,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Fatal("failed to init testing account")
|
||||
}
|
||||
|
||||
dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID)
|
||||
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
|
||||
}
|
||||
@@ -48,12 +49,12 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
DisabledManagementGroups: []string{group1ID},
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("failed to save testing account with new DNS settings")
|
||||
}
|
||||
|
||||
dnsSettings, err = am.GetDNSSettings(account.Id, dnsAdminUserID)
|
||||
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
|
||||
if err != nil {
|
||||
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
|
||||
}
|
||||
@@ -62,7 +63,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
|
||||
}
|
||||
|
||||
_, err = am.GetDNSSettings(account.Id, dnsRegularUserID)
|
||||
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
|
||||
if err == nil {
|
||||
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
|
||||
}
|
||||
@@ -122,7 +123,7 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
err = am.SaveDNSSettings(account.Id, testCase.userID, testCase.inputSettings)
|
||||
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
@@ -130,7 +131,7 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
updatedAccount, err := am.Store.GetAccount(account.Id)
|
||||
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Errorf("should be able to retrieve updated account, got err: %s", err)
|
||||
}
|
||||
@@ -164,7 +165,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
||||
newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
|
||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
||||
@@ -173,14 +174,14 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
dnsSettings := account.DNSSettings.Copy()
|
||||
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
||||
account.DNSSettings = dnsSettings
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
||||
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
|
||||
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
|
||||
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
||||
peer2AccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
|
||||
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
|
||||
@@ -194,13 +195,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -244,28 +245,28 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(dnsAccountID, dnsAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
|
||||
account.Users[dnsRegularUserID] = &User{
|
||||
Id: dnsRegularUserID,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1)
|
||||
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, err = am.AddPeer("", dnsAdminUserID, peer2)
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
account, err = am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -312,10 +313,10 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
Groups: []string{allGroup.ID},
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(account.Id)
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -51,13 +52,15 @@ func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralM
|
||||
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||
// head.
|
||||
func (e *EphemeralManager) LoadInitialPeers() {
|
||||
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
e.loadEphemeralPeers()
|
||||
e.loadEphemeralPeers(ctx)
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup)
|
||||
e.timer = time.AfterFunc(ephemeralLifeTime, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,12 +76,12 @@ func (e *EphemeralManager) Stop() {
|
||||
|
||||
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
|
||||
// is active the manager will not delete it while it is active.
|
||||
func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
|
||||
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
|
||||
if !peer.Ephemeral {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("remove peer from ephemeral list: %s", peer.ID)
|
||||
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
|
||||
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
@@ -94,16 +97,16 @@ func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
|
||||
|
||||
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
|
||||
// is inactive it will be deleted after the ephemeralLifeTime period.
|
||||
func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
|
||||
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
|
||||
if !peer.Ephemeral {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||
|
||||
a, err := e.store.GetAccountByPeerID(peer.ID)
|
||||
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to add peer to ephemeral list: %s", err)
|
||||
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -116,12 +119,14 @@ func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
|
||||
|
||||
e.addPeer(peer.ID, a, newDeadLine())
|
||||
if e.timer == nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) loadEphemeralPeers() {
|
||||
accounts := e.store.GetAllAccounts()
|
||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
accounts := e.store.GetAllAccounts(context.Background())
|
||||
t := newDeadLine()
|
||||
count := 0
|
||||
for _, a := range accounts {
|
||||
@@ -132,10 +137,10 @@ func (e *EphemeralManager) loadEphemeralPeers() {
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Debugf("loaded ephemeral peer(s): %d", count)
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) cleanup() {
|
||||
func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
log.Tracef("on ephemeral cleanup")
|
||||
deletePeers := make(map[string]*ephemeralPeer)
|
||||
|
||||
@@ -154,7 +159,9 @@ func (e *EphemeralManager) cleanup() {
|
||||
}
|
||||
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
} else {
|
||||
e.timer = nil
|
||||
}
|
||||
@@ -162,10 +169,10 @@ func (e *EphemeralManager) cleanup() {
|
||||
e.peersLock.Unlock()
|
||||
|
||||
for id, p := range deletePeers {
|
||||
log.Debugf("delete ephemeral peer: %s", id)
|
||||
err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator)
|
||||
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
|
||||
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete ephemeral peer: %s", err)
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,11 +14,11 @@ type MockStore struct {
|
||||
account *Account
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAllAccounts() []*Account {
|
||||
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
|
||||
return []*Account{s.account}
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAccountByPeerID(peerId string) (*Account, error) {
|
||||
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
|
||||
_, ok := s.account.Peers[peerId]
|
||||
if ok {
|
||||
return s.account, nil
|
||||
@@ -31,7 +32,7 @@ type MocAccountManager struct {
|
||||
store *MockStore
|
||||
}
|
||||
|
||||
func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error {
|
||||
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
delete(a.store.account.Peers, peerID)
|
||||
return nil //nolint:nil
|
||||
}
|
||||
@@ -52,9 +53,9 @@ func TestNewManager(t *testing.T) {
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers()
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup()
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
if len(store.account.Peers) != numberOfPeers {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
|
||||
@@ -77,11 +78,11 @@ func TestNewManagerPeerConnected(t *testing.T) {
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers()
|
||||
mgr.OnPeerConnected(store.account.Peers["ephemeral_peer_0"])
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup()
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
expected := numberOfPeers + 1
|
||||
if len(store.account.Peers) != expected {
|
||||
@@ -105,15 +106,15 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers()
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
for _, v := range store.account.Peers {
|
||||
mgr.OnPeerConnected(v)
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
|
||||
}
|
||||
mgr.OnPeerDisconnected(store.account.Peers["ephemeral_peer_0"])
|
||||
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup()
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
expected := numberOfPeers + numberOfEphemeralPeers - 1
|
||||
if len(store.account.Peers) != expected {
|
||||
@@ -122,7 +123,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId("my account", "", "")
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "")
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -11,11 +12,11 @@ import (
|
||||
)
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -29,7 +30,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events")
|
||||
}
|
||||
|
||||
events, err := am.eventStore.Get(accountID, 0, 10000, true)
|
||||
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -54,10 +55,10 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
|
||||
go func() {
|
||||
_, err := am.eventStore.Save(&activity.Event{
|
||||
_, err := am.eventStore.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activityID,
|
||||
InitiatorID: initiatorID,
|
||||
@@ -67,7 +68,7 @@ func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID str
|
||||
})
|
||||
if err != nil {
|
||||
// todo add metric
|
||||
log.Errorf("received an error while storing an activity event, error: %s", err)
|
||||
log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,7 +14,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
|
||||
accountID string, count int) {
|
||||
t.Helper()
|
||||
for i := 0; i < count; i++ {
|
||||
_, err := manager.eventStore.Save(&activity.Event{
|
||||
_, err := manager.eventStore.Save(context.Background(), &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: typ,
|
||||
InitiatorID: initiatorID,
|
||||
@@ -35,32 +36,32 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
|
||||
accountID := "accountID"
|
||||
|
||||
t.Run("get empty events list", func(t *testing.T) {
|
||||
events, err := manager.GetEvents(accountID, userID)
|
||||
events, err := manager.GetEvents(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
assert.Len(t, events, 0)
|
||||
_ = manager.eventStore.Close() //nolint
|
||||
_ = manager.eventStore.Close(context.Background()) //nolint
|
||||
})
|
||||
|
||||
t.Run("get events", func(t *testing.T) {
|
||||
generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10)
|
||||
events, err := manager.GetEvents(accountID, userID)
|
||||
events, err := manager.GetEvents(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Len(t, events, 10)
|
||||
_ = manager.eventStore.Close() //nolint
|
||||
_ = manager.eventStore.Close(context.Background()) //nolint
|
||||
})
|
||||
|
||||
t.Run("get events without duplicates", func(t *testing.T) {
|
||||
generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10)
|
||||
events, err := manager.GetEvents(accountID, userID)
|
||||
events, err := manager.GetEvents(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
assert.Len(t, events, 1)
|
||||
_ = manager.eventStore.Close() //nolint
|
||||
_ = manager.eventStore.Close(context.Background()) //nolint
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -48,8 +49,8 @@ type FileStore struct {
|
||||
type StoredAccount struct{}
|
||||
|
||||
// NewFileStore restores a store from the file located in the datadir
|
||||
func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
fs, err := restore(filepath.Join(dataDir, storeFileName))
|
||||
func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
fs, err := restore(ctx, filepath.Join(dataDir, storeFileName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -58,27 +59,27 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err
|
||||
}
|
||||
|
||||
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
|
||||
func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
store, err := NewFileStore(dataDir, metrics)
|
||||
func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
store, err := NewFileStore(ctx, dataDir, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(sqlStore.GetInstallationID())
|
||||
err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range sqlStore.GetAllAccounts() {
|
||||
for _, account := range sqlStore.GetAllAccounts(ctx) {
|
||||
store.Accounts[account.Id] = account
|
||||
}
|
||||
|
||||
return store, store.persist(store.storeFile)
|
||||
return store, store.persist(ctx, store.storeFile)
|
||||
}
|
||||
|
||||
// restore the state of the store from the file.
|
||||
// Creates a new empty store file if doesn't exist
|
||||
func restore(file string) (*FileStore, error) {
|
||||
func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
if _, err := os.Stat(file); os.IsNotExist(err) {
|
||||
// create a new FileStore if previously didn't exist (e.g. first run)
|
||||
s := &FileStore{
|
||||
@@ -95,7 +96,7 @@ func restore(file string) (*FileStore, error) {
|
||||
storeFile: file,
|
||||
}
|
||||
|
||||
err = s.persist(file)
|
||||
err = s.persist(ctx, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -165,7 +166,7 @@ func restore(file string) (*FileStore, error) {
|
||||
// for data migration. Can be removed once most base will be with labels
|
||||
existingLabels := account.getPeerDNSLabels()
|
||||
if len(existingLabels) != len(account.Peers) {
|
||||
addPeerLabelsToAccount(account, existingLabels)
|
||||
addPeerLabelsToAccount(ctx, account, existingLabels)
|
||||
}
|
||||
|
||||
// TODO: delete this block after migration
|
||||
@@ -178,7 +179,7 @@ func restore(file string) (*FileStore, error) {
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
||||
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
||||
// if the All group didn't exist we probably don't have routes to update
|
||||
continue
|
||||
}
|
||||
@@ -236,7 +237,7 @@ func restore(file string) (*FileStore, error) {
|
||||
}
|
||||
|
||||
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
|
||||
err = store.persist(store.storeFile)
|
||||
err = store.persist(ctx, store.storeFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -246,7 +247,7 @@ func restore(file string) (*FileStore, error) {
|
||||
|
||||
// persist account data to a file
|
||||
// It is recommended to call it with locking FileStore.mux
|
||||
func (s *FileStore) persist(file string) error {
|
||||
func (s *FileStore) persist(ctx context.Context, file string) error {
|
||||
start := time.Now()
|
||||
err := util.WriteJson(file, s)
|
||||
if err != nil {
|
||||
@@ -256,23 +257,23 @@ func (s *FileStore) persist(file string) error {
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||
}
|
||||
log.Debugf("took %d ms to persist the FileStore", took.Milliseconds())
|
||||
log.WithContext(ctx).Debugf("took %d ms to persist the FileStore", took.Milliseconds())
|
||||
return nil
|
||||
}
|
||||
|
||||
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireGlobalLock() (unlock func()) {
|
||||
log.Debugf("acquiring global lock")
|
||||
func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
log.WithContext(ctx).Debugf("acquiring global lock")
|
||||
start := time.Now()
|
||||
s.globalAccountLock.Lock()
|
||||
|
||||
unlock = func() {
|
||||
s.globalAccountLock.Unlock()
|
||||
log.Debugf("released global lock in %v", time.Since(start))
|
||||
log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start))
|
||||
}
|
||||
|
||||
took := time.Since(start)
|
||||
log.Debugf("took %v to acquire global lock", took)
|
||||
log.WithContext(ctx).Debugf("took %v to acquire global lock", took)
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
|
||||
}
|
||||
@@ -281,8 +282,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
|
||||
}
|
||||
|
||||
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
log.Debugf("acquiring lock for account %s", accountID)
|
||||
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
mtx := value.(*sync.Mutex)
|
||||
@@ -290,7 +291,7 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
|
||||
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
@@ -298,11 +299,11 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
|
||||
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
|
||||
// This method is still returns a write lock as file store can't handle read locks
|
||||
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
return s.AcquireAccountWriteLock(accountID)
|
||||
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
return s.AcquireAccountWriteLock(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveAccount(account *Account) error {
|
||||
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -338,10 +339,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
|
||||
}
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return s.persist(ctx, s.storeFile)
|
||||
}
|
||||
|
||||
func (s *FileStore) DeleteAccount(account *Account) error {
|
||||
func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -373,7 +374,7 @@ func (s *FileStore) DeleteAccount(account *Account) error {
|
||||
|
||||
delete(s.Accounts, account.Id)
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return s.persist(ctx, s.storeFile)
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
|
||||
@@ -397,7 +398,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
}
|
||||
|
||||
// GetAccountByPrivateDomain returns account by private domain
|
||||
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -415,7 +416,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccountBySetupKey returns account by setup key id
|
||||
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -433,7 +434,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
|
||||
func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
|
||||
func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -446,7 +447,7 @@ func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
|
||||
}
|
||||
|
||||
// GetUserByTokenID returns a User object a tokenID belongs to
|
||||
func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -469,7 +470,7 @@ func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
}
|
||||
|
||||
// GetAllAccounts returns all accounts
|
||||
func (s *FileStore) GetAllAccounts() (all []*Account) {
|
||||
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
for _, a := range s.Accounts {
|
||||
@@ -490,7 +491,7 @@ func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccount returns an account for ID
|
||||
func (s *FileStore) GetAccount(accountID string) (*Account, error) {
|
||||
func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -503,7 +504,7 @@ func (s *FileStore) GetAccount(accountID string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccountByUser returns a user account
|
||||
func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -521,7 +522,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccountByPeerID returns an account for a given peer ID
|
||||
func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -539,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
// check Account.Peers for a match
|
||||
if _, ok := account.Peers[peerID]; !ok {
|
||||
delete(s.PeerID2AccountID, peerID)
|
||||
log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
|
||||
log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
|
||||
return nil, status.NewPeerNotFoundError(peerID)
|
||||
}
|
||||
|
||||
@@ -547,7 +548,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
|
||||
func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -572,14 +573,14 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
}
|
||||
if stale {
|
||||
delete(s.PeerKeyID2AccountID, peerKey)
|
||||
log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
|
||||
log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
return account.Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -603,7 +604,7 @@ func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
|
||||
func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -615,7 +616,7 @@ func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
|
||||
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -638,7 +639,7 @@ func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) {
|
||||
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -656,13 +657,13 @@ func (s *FileStore) GetInstallationID() string {
|
||||
}
|
||||
|
||||
// SaveInstallationID saves the installation ID
|
||||
func (s *FileStore) SaveInstallationID(ID string) error {
|
||||
func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.InstallationID = ID
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return s.persist(ctx, s.storeFile)
|
||||
}
|
||||
|
||||
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
|
||||
@@ -732,13 +733,13 @@ func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *
|
||||
}
|
||||
|
||||
// Close the FileStore persisting data to disk
|
||||
func (s *FileStore) Close() error {
|
||||
func (s *FileStore) Close(ctx context.Context) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
log.Infof("closing FileStore")
|
||||
log.WithContext(ctx).Infof("closing FileStore")
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return s.persist(ctx, s.storeFile)
|
||||
}
|
||||
|
||||
// GetStoreEngine returns FileStoreEngine
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"net"
|
||||
"path/filepath"
|
||||
@@ -27,12 +28,12 @@ func TestStalePeerIndices(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
peerID := "some_peer"
|
||||
@@ -42,24 +43,24 @@ func TestStalePeerIndices(t *testing.T) {
|
||||
Key: peerKey,
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
account.DeletePeer(peerID)
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.GetAccountByPeerID(peerID)
|
||||
_, err = store.GetAccountByPeerID(context.Background(), peerID)
|
||||
require.Error(t, err, "expecting to get an error when found stale index")
|
||||
|
||||
_, err = store.GetAccountByPeerPubKey(peerKey)
|
||||
_, err = store.GetAccountByPeerPubKey(context.Background(), peerKey)
|
||||
require.Error(t, err, "expecting to get an error when found stale index")
|
||||
}
|
||||
|
||||
func TestNewStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
|
||||
if store.Accounts == nil || len(store.Accounts) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
@@ -88,9 +89,9 @@ func TestNewStore(t *testing.T) {
|
||||
|
||||
func TestSaveAccount(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
@@ -103,7 +104,7 @@ func TestSaveAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
// SaveAccount should trigger persist
|
||||
err := store.SaveAccount(account)
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -133,11 +134,11 @@ func TestDeleteAccount(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
|
||||
var account *Account
|
||||
for _, a := range store.Accounts {
|
||||
@@ -147,7 +148,7 @@ func TestDeleteAccount(t *testing.T) {
|
||||
|
||||
require.NotNil(t, account, "failed to restore a FileStore file and get at least one account")
|
||||
|
||||
err = store.DeleteAccount(account)
|
||||
err = store.DeleteAccount(context.Background(), account)
|
||||
require.NoError(t, err, "failed to delete account, error: %v", err)
|
||||
|
||||
_, ok := store.Accounts[account.Id]
|
||||
@@ -183,9 +184,9 @@ func TestDeleteAccount(t *testing.T) {
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
@@ -228,12 +229,12 @@ func TestStore(t *testing.T) {
|
||||
})
|
||||
|
||||
// SaveAccount should trigger persist
|
||||
err := store.SaveAccount(account)
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
restored, err := NewFileStore(store.storeFile, nil)
|
||||
restored, err := NewFileStore(context.Background(), store.storeFile, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -281,7 +282,7 @@ func TestRestore(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -319,7 +320,7 @@ func TestRestoreGroups_Migration(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -332,11 +333,11 @@ func TestRestoreGroups_Migration(t *testing.T) {
|
||||
Name: "All",
|
||||
},
|
||||
}
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "failed to save account")
|
||||
|
||||
// restore account with default group with empty Issue field
|
||||
if store, err = NewFileStore(storeDir, nil); err != nil {
|
||||
if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil {
|
||||
return
|
||||
}
|
||||
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||
@@ -353,18 +354,18 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
existingDomain := "test.com"
|
||||
|
||||
account, err := store.GetAccountByPrivateDomain(existingDomain)
|
||||
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
|
||||
require.NoError(t, err, "should found account")
|
||||
require.Equal(t, existingDomain, account.Domain, "domains should match")
|
||||
|
||||
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
|
||||
_, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
|
||||
require.Error(t, err, "should return error on domain lookup")
|
||||
}
|
||||
|
||||
@@ -382,7 +383,7 @@ func TestFileStore_GetAccount(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -393,7 +394,7 @@ func TestFileStore_GetAccount(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := store.GetAccount(expected.Id)
|
||||
account, err := store.GetAccount(context.Background(), expected.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -424,13 +425,13 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken
|
||||
tokenID, err := store.GetTokenIDByHashedToken(hashedToken)
|
||||
tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -441,7 +442,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
|
||||
|
||||
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
|
||||
|
||||
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
|
||||
@@ -478,13 +479,13 @@ func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234"))
|
||||
_, err = store.GetTokenIDByHashedToken(string(wrongToken[:]))
|
||||
_, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:]))
|
||||
|
||||
assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid")
|
||||
}
|
||||
@@ -503,13 +504,13 @@ func TestFileStore_GetUserByTokenID(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
|
||||
user, err := store.GetUserByTokenID(tokenID)
|
||||
user, err := store.GetUserByTokenID(context.Background(), tokenID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -531,13 +532,13 @@ func TestFileStore_GetUserByTokenID_Failure(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrongTokenID := "someNonExistingTokenID"
|
||||
_, err = store.GetUserByTokenID(wrongTokenID)
|
||||
_, err = store.GetUserByTokenID(context.Background(), wrongTokenID)
|
||||
|
||||
assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid")
|
||||
}
|
||||
@@ -550,7 +551,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -576,7 +577,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -602,11 +603,11 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
@@ -625,7 +626,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
|
||||
account.Peers[peer.ID] = peer
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
|
||||
@@ -636,7 +637,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
|
||||
err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(account.Id)
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual := account.Peers[peer.ID].Location
|
||||
@@ -645,7 +646,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
|
||||
|
||||
func newStore(t *testing.T) *FileStore {
|
||||
t.Helper()
|
||||
store, err := NewFileStore(t.TempDir(), nil)
|
||||
store, err := NewFileStore(context.Background(), t.TempDir(), nil)
|
||||
if err != nil {
|
||||
t.Errorf("failed creating a new store")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package geolocation
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@@ -52,7 +53,7 @@ type Country struct {
|
||||
CountryName string
|
||||
}
|
||||
|
||||
func NewGeolocation(dataDir string) (*Geolocation, error) {
|
||||
func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
|
||||
if err := loadGeolocationDatabases(dataDir); err != nil {
|
||||
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
|
||||
}
|
||||
@@ -68,7 +69,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
locationDB, err := NewSqliteStore(dataDir)
|
||||
locationDB, err := NewSqliteStore(ctx, dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -83,7 +84,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
go geo.reloader()
|
||||
go geo.reloader(ctx)
|
||||
|
||||
return geo, nil
|
||||
}
|
||||
@@ -165,19 +166,19 @@ func (gl *Geolocation) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gl *Geolocation) reloader() {
|
||||
func (gl *Geolocation) reloader(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-gl.stopCh:
|
||||
return
|
||||
case <-time.After(gl.reloadCheckInterval):
|
||||
if err := gl.locationDB.reload(); err != nil {
|
||||
log.Errorf("geonames db reload failed: %s", err)
|
||||
if err := gl.locationDB.reload(ctx); err != nil {
|
||||
log.WithContext(ctx).Errorf("geonames db reload failed: %s", err)
|
||||
}
|
||||
|
||||
newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
continue
|
||||
}
|
||||
if !bytes.Equal(gl.sha256sum, newSha256sum1) {
|
||||
@@ -186,30 +187,30 @@ func (gl *Geolocation) reloader() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
continue
|
||||
}
|
||||
if !bytes.Equal(newSha256sum1, newSha256sum2) {
|
||||
log.Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
|
||||
log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
|
||||
continue
|
||||
}
|
||||
err = gl.reload(newSha256sum2)
|
||||
err = gl.reload(ctx, newSha256sum2)
|
||||
if err != nil {
|
||||
log.Errorf("mmdb reload failed: %s", err)
|
||||
log.WithContext(ctx).Errorf("mmdb reload failed: %s", err)
|
||||
}
|
||||
} else {
|
||||
log.Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
|
||||
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
|
||||
gl.mmdbPath, gl.reloadCheckInterval.Seconds())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (gl *Geolocation) reload(newSha256sum []byte) error {
|
||||
func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error {
|
||||
gl.mux.Lock()
|
||||
defer gl.mux.Unlock()
|
||||
|
||||
log.Infof("Reloading '%s'", gl.mmdbPath)
|
||||
log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath)
|
||||
|
||||
err := gl.db.Close()
|
||||
if err != nil {
|
||||
@@ -224,7 +225,7 @@ func (gl *Geolocation) reload(newSha256sum []byte) error {
|
||||
gl.db = db
|
||||
gl.sha256sum = newSha256sum
|
||||
|
||||
log.Infof("Successfully reloaded '%s'", gl.mmdbPath)
|
||||
log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package geolocation
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -50,10 +51,10 @@ type SqliteStore struct {
|
||||
sha256sum []byte
|
||||
}
|
||||
|
||||
func NewSqliteStore(dataDir string) (*SqliteStore, error) {
|
||||
func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) {
|
||||
file := filepath.Join(dataDir, GeoSqliteDBFile)
|
||||
|
||||
db, err := connectDB(file)
|
||||
db, err := connectDB(ctx, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -115,13 +116,13 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
|
||||
}
|
||||
|
||||
// reload attempts to reload the SqliteStore's database if the database file has changed.
|
||||
func (s *SqliteStore) reload() error {
|
||||
func (s *SqliteStore) reload(ctx context.Context) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
newSha256sum1, err := calculateFileSHA256(s.filePath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
|
||||
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(s.sha256sum, newSha256sum1) {
|
||||
@@ -136,11 +137,11 @@ func (s *SqliteStore) reload() error {
|
||||
return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath)
|
||||
}
|
||||
|
||||
log.Infof("Reloading '%s'", s.filePath)
|
||||
log.WithContext(ctx).Infof("Reloading '%s'", s.filePath)
|
||||
_ = s.close()
|
||||
s.closed = true
|
||||
|
||||
newDb, err := connectDB(s.filePath)
|
||||
newDb, err := connectDB(ctx, s.filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -148,9 +149,9 @@ func (s *SqliteStore) reload() error {
|
||||
s.closed = false
|
||||
s.db = newDb
|
||||
|
||||
log.Infof("Successfully reloaded '%s'", s.filePath)
|
||||
log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath)
|
||||
} else {
|
||||
log.Tracef("No changes in '%s', no need to reload", s.filePath)
|
||||
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -168,10 +169,10 @@ func (s *SqliteStore) close() error {
|
||||
}
|
||||
|
||||
// connectDB connects to an SQLite database and prepares it by setting up an in-memory database.
|
||||
func connectDB(filePath string) (*gorm.DB, error) {
|
||||
func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.Debugf("took %v to setup geoname db", time.Since(start))
|
||||
log.WithContext(ctx).Debugf("took %v to setup geoname db", time.Since(start))
|
||||
}()
|
||||
|
||||
_, err := fileExists(filePath)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@@ -21,11 +22,11 @@ func (e *GroupLinkError) Error() string {
|
||||
}
|
||||
|
||||
// GetGroup object of the peers
|
||||
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,11 +49,11 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -75,11 +76,11 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -108,11 +109,11 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -150,11 +151,12 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
|
||||
account.Groups[newGroup.ID] = newGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
// todo: check if groups is in use by dns, acl, routes and before/after peers
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
// the following snippet tracks the activity and stores the group events in the event store.
|
||||
// It has to happen after all the operations have been successfully performed.
|
||||
@@ -165,16 +167,16 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
} else {
|
||||
addedPeers = append(addedPeers, newGroup.Peers...)
|
||||
am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
|
||||
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
|
||||
}
|
||||
|
||||
for _, p := range addedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
continue
|
||||
}
|
||||
am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer,
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
||||
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
@@ -184,10 +186,10 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
|
||||
for _, p := range removedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
continue
|
||||
}
|
||||
am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
||||
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
@@ -213,11 +215,11 @@ func difference(a, b []string) []string {
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountId)
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -315,23 +317,24 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
// todo: check if groups is in use by dns, acl, routes and if it has peers
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -345,11 +348,11 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -371,21 +374,22 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
// todo: check if groups is in use by dns, acl, routes
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -399,13 +403,14 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID stri
|
||||
for i, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||
if err := am.Store.SaveAccount(account); err != nil {
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
// todo: check if groups is in use by dns, acl, routes
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
@@ -26,7 +27,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
}
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = nbgroup.GroupIssuedIntegration
|
||||
err = am.SaveGroup(account.Id, groupAdminUserID, group)
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err != nil {
|
||||
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration)
|
||||
}
|
||||
@@ -34,7 +35,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = nbgroup.GroupIssuedJWT
|
||||
err = am.SaveGroup(account.Id, groupAdminUserID, group)
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err != nil {
|
||||
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT)
|
||||
}
|
||||
@@ -42,7 +43,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = nbgroup.GroupIssuedAPI
|
||||
group.ID = ""
|
||||
err = am.SaveGroup(account.Id, groupAdminUserID, group)
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err == nil {
|
||||
t.Errorf("should not create api group with the same name, %s", group.Name)
|
||||
}
|
||||
@@ -104,7 +105,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
err = am.DeleteGroup(account.Id, groupAdminUserID, testCase.groupID)
|
||||
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, testCase.groupID)
|
||||
if err == nil {
|
||||
t.Errorf("delete %s group successfully", testCase.groupID)
|
||||
return
|
||||
@@ -225,7 +226,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
@@ -233,18 +234,18 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
account.SetupKeys[setupKey.Id] = setupKey
|
||||
account.Users[user.Id] = user
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForIntegration)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
||||
|
||||
return am.Store.GetAccount(account.Id)
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
||||
@@ -11,12 +11,14 @@ import (
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@@ -40,7 +42,7 @@ type GRPCServer struct {
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
|
||||
func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -50,6 +52,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
|
||||
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
||||
jwtValidator, err = jwtclaims.NewJWTValidator(
|
||||
ctx,
|
||||
config.HttpConfig.AuthIssuer,
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.AuthKeysLocation,
|
||||
@@ -59,7 +62,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Debug("unable to use http config to create new jwt middleware")
|
||||
log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware")
|
||||
}
|
||||
|
||||
if appMetrics != nil {
|
||||
@@ -126,47 +129,61 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequest()
|
||||
}
|
||||
realIP := getRealIP(srv.Context())
|
||||
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
realIP := getRealIP(ctx)
|
||||
|
||||
syncReq := &proto.SyncRequest{}
|
||||
peerKey, err := s.parseRequest(req, syncReq)
|
||||
peerKey, err := s.parseRequest(ctx, req, syncReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
|
||||
accountID = "UNKNOWN"
|
||||
}
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
if syncReq.GetMeta() == nil {
|
||||
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP)
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
return mapError(err)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(peerKey, peer, netMap, srv)
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
|
||||
if err != nil {
|
||||
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
updates := s.peersUpdateManager.CreateChannel(peer.ID)
|
||||
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
|
||||
|
||||
s.ephemeralManager.OnPeerConnected(peer)
|
||||
s.ephemeralManager.OnPeerConnected(ctx, peer)
|
||||
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
||||
s.turnCredentialsManager.SetupRefresh(ctx, peer.ID)
|
||||
}
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
return s.handleUpdates(peerKey, peer, updates, srv)
|
||||
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
@@ -176,21 +193,21 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
|
||||
}
|
||||
|
||||
if !open {
|
||||
log.Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(peer)
|
||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
return nil
|
||||
}
|
||||
log.Debugf("received an update for peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(peerKey, peer, update, srv); err != nil {
|
||||
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// condition when client <-> server connection has been terminated
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(peer)
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
@@ -198,10 +215,10 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(peer)
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
@@ -209,37 +226,37 @@ func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(peer)
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.Debugf("sent an update to peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.CancelPeerRoutines(peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
if s.jwtValidator == nil {
|
||||
return "", status.Error(codes.Internal, "no jwt validator set")
|
||||
}
|
||||
|
||||
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
|
||||
token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
_, _, err = s.accountManager.GetAccountFromToken(claims)
|
||||
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil {
|
||||
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
|
||||
return "", status.Errorf(codes.PermissionDenied, err.Error())
|
||||
}
|
||||
|
||||
@@ -247,7 +264,7 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
||||
}
|
||||
|
||||
// maps internal internalStatus.Error to gRPC status.Error
|
||||
func mapError(err error) error {
|
||||
func mapError(ctx context.Context, err error) error {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case internalStatus.PermissionDenied:
|
||||
@@ -263,11 +280,11 @@ func mapError(err error) error {
|
||||
default:
|
||||
}
|
||||
}
|
||||
log.Errorf("got an unhandled error: %s", err)
|
||||
log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
|
||||
return status.Errorf(codes.Internal, "failed handling request")
|
||||
}
|
||||
|
||||
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
if meta == nil {
|
||||
return nbpeer.PeerSystemMeta{}
|
||||
}
|
||||
@@ -281,7 +298,7 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
for _, addr := range meta.GetNetworkAddresses() {
|
||||
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
|
||||
log.WithContext(ctx).Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
|
||||
continue
|
||||
}
|
||||
networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{
|
||||
@@ -321,10 +338,10 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
|
||||
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
|
||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
|
||||
}
|
||||
|
||||
@@ -351,22 +368,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
loginReq := &proto.LoginRequest{}
|
||||
peerKey, err := s.parseRequest(req, loginReq)
|
||||
peerKey, err := s.parseRequest(ctx, req, loginReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
|
||||
accountID = "UNKNOWN"
|
||||
}
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
if loginReq.GetMeta() == nil {
|
||||
msg := status.Errorf(codes.FailedPrecondition,
|
||||
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
log.Warn(msg)
|
||||
log.WithContext(ctx).Warn(msg)
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
userID, err := s.processJwtToken(loginReq, peerKey)
|
||||
userID, err := s.processJwtToken(ctx, loginReq, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -376,33 +403,33 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||
}
|
||||
|
||||
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
|
||||
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, PeerLogin{
|
||||
WireGuardPubKey: peerKey.String(),
|
||||
SSHKey: string(sshKey),
|
||||
Meta: extractPeerMeta(loginReq.GetMeta()),
|
||||
Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
|
||||
UserID: userID,
|
||||
SetupKey: loginReq.GetSetupKey(),
|
||||
ConnectionIP: realIP,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(err)
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
// if the login request contains setup key then it is a registration request
|
||||
if loginReq.GetSetupKey() != "" {
|
||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
|
||||
Checks: toProtocolChecks(s.accountManager, peerKey.String()),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||
if err != nil {
|
||||
log.Warnf("failed encrypting peer %s message", peer.ID)
|
||||
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
@@ -417,16 +444,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
//
|
||||
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
|
||||
// registered or if it uses a setup key to register.
|
||||
func (s *GRPCServer) processJwtToken(loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
userID := ""
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
userID, err = s.validateToken(loginReq.GetJwtToken())
|
||||
userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+
|
||||
log.WithContext(ctx).Warnf("failed validating JWT token sent from peer %s with error %v. "+
|
||||
"Trying again as it may be due to the IdP cache issue", peerKey.String(), err)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
@@ -520,7 +547,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
|
||||
return remotePeers
|
||||
}
|
||||
|
||||
func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
|
||||
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
||||
|
||||
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
||||
@@ -551,7 +578,7 @@ func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.
|
||||
FirewallRules: firewallRules,
|
||||
FirewallRulesIsEmpty: len(firewallRules) == 0,
|
||||
},
|
||||
Checks: toProtocolChecks(accountManager, peer.Key),
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -561,7 +588,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
|
||||
}
|
||||
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
|
||||
// make secret time based TURN credentials optional
|
||||
var turnCredentials *TURNCredentials
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
@@ -570,7 +597,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
|
||||
} else {
|
||||
turnCredentials = nil
|
||||
}
|
||||
plainResp := toSyncResponse(s.accountManager, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
@@ -583,7 +610,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed sending SyncResponse %v", err)
|
||||
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
@@ -597,14 +624,14 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
||||
log.Warn(errMSG)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||
log.Warn(errMSG)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
@@ -645,18 +672,18 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
|
||||
// This is used for initiating an Oauth 2 pkce authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
|
||||
log.Warn(errMSG)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||
log.Warn(errMSG)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
@@ -692,10 +719,10 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr
|
||||
// peer's under the same account of any updates.
|
||||
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
realIP := getRealIP(ctx)
|
||||
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
syncMetaReq := &proto.SyncMetaRequest{}
|
||||
peerKey, err := s.parseRequest(req, syncMetaReq)
|
||||
peerKey, err := s.parseRequest(ctx, req, syncMetaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -703,27 +730,21 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
|
||||
if syncMetaReq.GetMeta() == nil {
|
||||
msg := status.Errorf(codes.FailedPrecondition,
|
||||
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
log.Warn(msg)
|
||||
log.WithContext(ctx).Warn(msg)
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncPeerMeta(peerKey.String(), extractPeerMeta(syncMetaReq.GetMeta()))
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
|
||||
if err != nil {
|
||||
return nil, mapError(err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// toProtocolChecks returns posture checks for the peer that needs to be evaluated on the client side.
|
||||
func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Checks {
|
||||
postureChecks, err := accountManager.GetPeerAppliedPostureChecks(peerKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting peer's: %s posture checks: %v", peerKey, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
protoChecks := make([]*proto.Checks, 0)
|
||||
// toProtocolChecks converts posture checks to protocol checks.
|
||||
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
|
||||
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
|
||||
for _, postureCheck := range postureChecks {
|
||||
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
|
||||
}
|
||||
@@ -732,7 +753,7 @@ func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Ch
|
||||
}
|
||||
|
||||
// toProtocolCheck converts a posture.Checks to a proto.Checks.
|
||||
func toProtocolCheck(postureCheck posture.Checks) *proto.Checks {
|
||||
func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
|
||||
protoCheck := &proto.Checks{}
|
||||
|
||||
if check := postureCheck.Checks.ProcessCheck; check != nil {
|
||||
|
||||
@@ -35,34 +35,34 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
||||
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(account)
|
||||
util.WriteJSONObject(w, []*api.Account{resp})
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
_, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
accountID := vars["accountId"]
|
||||
if len(accountID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -96,15 +96,15 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(updatedAccount)
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// DeleteAccount is a HTTP DELETE handler to delete an account
|
||||
@@ -118,17 +118,17 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
|
||||
vars := mux.Vars(r)
|
||||
targetAccountID := vars["accountId"]
|
||||
if len(targetAccountID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid account ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteAccount(targetAccountID, claims.UserId)
|
||||
err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
func toAccountResponse(account *server.Account) *api.Account {
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -22,10 +23,10 @@ import (
|
||||
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
|
||||
return &AccountsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return account, admin, nil
|
||||
},
|
||||
UpdateAccountSettingsFunc: func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
|
||||
@@ -32,16 +32,16 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
|
||||
// GetDNSSettings returns the DNS settings for the account
|
||||
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(account.Id, user.Id)
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -49,15 +49,15 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
|
||||
DisabledManagementGroups: dnsSettings.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, apiDNSSettings)
|
||||
util.WriteJSONObject(r.Context(), w, apiDNSSettings)
|
||||
}
|
||||
|
||||
// UpdateDNSSettings handles update to DNS settings of an account
|
||||
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -72,9 +72,9 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
|
||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveDNSSettings(account.Id, user.Id, updateDNSSettings)
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -82,5 +82,5 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
|
||||
DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -42,16 +43,16 @@ var testingDNSSettingsAccount = &server.Account{
|
||||
func initDNSSettingsTestData() *DNSSettingsHandler {
|
||||
return &DNSSettingsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) {
|
||||
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
|
||||
return &testingDNSSettingsAccount.DNSSettings, nil
|
||||
},
|
||||
SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
|
||||
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
|
||||
if dnsSettingsToSave != nil {
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -33,16 +34,16 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
|
||||
// GetAllEvents list of the given account
|
||||
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountEvents, err := h.accountManager.GetEvents(account.Id, user.Id)
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
events := make([]*api.Event, len(accountEvents))
|
||||
@@ -50,20 +51,20 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
events[i] = toEventResponse(e)
|
||||
}
|
||||
|
||||
err = h.fillEventsWithUserInfo(events, account.Id, user.Id)
|
||||
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, events)
|
||||
util.WriteJSONObject(r.Context(), w, events)
|
||||
}
|
||||
|
||||
func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error {
|
||||
func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
|
||||
// build email, name maps based on users
|
||||
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
|
||||
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get users from account: %s", err)
|
||||
log.WithContext(ctx).Errorf("failed to get users from account: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -80,7 +81,7 @@ func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, u
|
||||
if event.InitiatorEmail == "" {
|
||||
event.InitiatorEmail, ok = emails[event.InitiatorId]
|
||||
if !ok {
|
||||
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
|
||||
log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -22,13 +23,13 @@ import (
|
||||
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
|
||||
return &EventsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetEventsFunc: func(accountID, userID string) ([]*activity.Event, error) {
|
||||
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
if accountID == account {
|
||||
return events, nil
|
||||
}
|
||||
return []*activity.Event{}, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
@@ -37,7 +38,7 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
return make([]*server.UserInfo, 0), nil
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -35,13 +36,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
||||
err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile))
|
||||
assert.NoError(t, err)
|
||||
|
||||
geo, err := geolocation.NewGeolocation(tempDir)
|
||||
geo, err := geolocation.NewGeolocation(context.Background(), tempDir)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() { _ = geo.Stop() })
|
||||
|
||||
return &GeolocationsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
|
||||
@@ -40,19 +40,19 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca
|
||||
// GetAllCountries retrieves a list of all countries
|
||||
func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if l.geolocationManager == nil {
|
||||
// TODO: update error message to include geo db self hosted doc link when ready
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allCountries, err := l.geolocationManager.GetAllCountries()
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -60,32 +60,32 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req
|
||||
for _, country := range allCountries {
|
||||
countries = append(countries, toCountryResponse(country))
|
||||
}
|
||||
util.WriteJSONObject(w, countries)
|
||||
util.WriteJSONObject(r.Context(), w, countries)
|
||||
}
|
||||
|
||||
// GetCitiesByCountry retrieves a list of cities based on the given country code
|
||||
func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
countryCode := vars["country"]
|
||||
if !countryCodeRegex.MatchString(countryCode) {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid country code"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid country code"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if l.geolocationManager == nil {
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -93,12 +93,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
|
||||
for _, city := range allCities {
|
||||
cities = append(cities, toCityResponse(city))
|
||||
}
|
||||
util.WriteJSONObject(w, cities)
|
||||
util.WriteJSONObject(r.Context(), w, cities)
|
||||
}
|
||||
|
||||
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
||||
claims := l.claimsExtractor.FromRequestContext(r)
|
||||
_, user, err := l.accountManager.GetAccountFromToken(claims)
|
||||
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -35,16 +35,16 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
|
||||
// GetAllGroups list for the account
|
||||
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.accountManager.GetAllGroups(account.Id, user.Id)
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -53,42 +53,42 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, groupsResponse)
|
||||
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
||||
}
|
||||
|
||||
// UpdateGroup handles update to a group identified by a given ID
|
||||
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
groupID, ok := vars["groupId"]
|
||||
if !ok {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
|
||||
return
|
||||
}
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
eg, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
if allGroup.ID == groupID {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -118,21 +118,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
IntegrationReference: eg.IntegrationReference,
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
|
||||
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
||||
util.WriteError(err, w)
|
||||
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
||||
}
|
||||
|
||||
// CreateGroup handles group creation request
|
||||
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -160,62 +160,62 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
|
||||
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
||||
}
|
||||
|
||||
// DeleteGroup handles group deletion request
|
||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
aID := account.Id
|
||||
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteGroup(aID, user.Id, groupID)
|
||||
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
|
||||
if err != nil {
|
||||
_, ok := err.(*server.GroupLinkError)
|
||||
if ok {
|
||||
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetGroup returns a group
|
||||
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -223,19 +223,19 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||
case http.MethodGet:
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id)
|
||||
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toGroupResponse(account, group))
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -32,13 +33,13 @@ var TestPeers = map[string]*nbpeer.Peer{
|
||||
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
return &GroupsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error {
|
||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
||||
if !strings.HasPrefix(group.ID, "id-") {
|
||||
group.ID = "id-was-set"
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) {
|
||||
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
|
||||
if groupID != "idofthegroup" {
|
||||
return nil, status.Errorf(status.NotFound, "not found")
|
||||
}
|
||||
@@ -55,7 +56,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
@@ -70,7 +71,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
DeleteGroupFunc: func(accountID, userId, groupID string) error {
|
||||
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
||||
if groupID == "linked-grp" {
|
||||
return &server.GroupLinkError{
|
||||
Resource: "something",
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
s "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
@@ -57,6 +58,11 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
claimsExtractor = jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
)
|
||||
|
||||
acMiddleware := middleware.NewAccessControl(
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
@@ -15,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
@@ -46,15 +47,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
user, err := a.getUser(claims)
|
||||
user, err := a.getUser(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get user from claims: %s", err)
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if user.IsBlocked() {
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -63,12 +64,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||
|
||||
if tokenPathRegexp.MatchString(r.URL.Path) {
|
||||
log.Debugf("valid Path")
|
||||
log.WithContext(r.Context()).Debugf("valid Path")
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@@ -19,16 +20,16 @@ import (
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
|
||||
// MarkPATUsedFunc function
|
||||
type MarkPATUsedFunc func(token string) error
|
||||
type MarkPATUsedFunc func(ctx context.Context, token string) error
|
||||
|
||||
// CheckUserAccessByJWTGroupsFunc function
|
||||
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
@@ -85,23 +86,27 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
case "bearer":
|
||||
err := m.checkJWTFromRequest(w, r, auth)
|
||||
if err != nil {
|
||||
log.Errorf("Error when validating JWT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
err := m.checkPATFromRequest(w, r, auth)
|
||||
if err != nil {
|
||||
log.Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
}
|
||||
claims := m.claimsExtractor.FromRequestContext(r)
|
||||
//nolint
|
||||
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -114,7 +119,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
}
|
||||
|
||||
validatedToken, err := m.validateAndParseToken(token)
|
||||
validatedToken, err := m.validateAndParseToken(r.Context(), token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -123,7 +128,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.verifyUserAccess(validatedToken); err != nil {
|
||||
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -138,9 +143,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
// verifyUserAccess checks if a user, based on a validated JWT token,
|
||||
// is allowed access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and designated certain groups with access permissions.
|
||||
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
|
||||
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
|
||||
authClaims := m.claimsExtractor.FromToken(validatedToken)
|
||||
return m.checkUserAccessByJWTGroups(authClaims)
|
||||
return m.checkUserAccessByJWTGroups(ctx, authClaims)
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
@@ -152,7 +157,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(token)
|
||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
@@ -160,7 +165,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.markPATUsed(pat.ID)
|
||||
err = m.markPATUsed(r.Context(), pat.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -15,15 +16,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
audience = "audience"
|
||||
userIDClaim = "userIDClaim"
|
||||
accountID = "accountID"
|
||||
domain = "domain"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
PAT = "nbp_PAT"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
audience = "audience"
|
||||
userIDClaim = "userIDClaim"
|
||||
accountID = "accountID"
|
||||
domain = "domain"
|
||||
domainCategory = "domainCategory"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
PAT = "nbp_PAT"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
@@ -47,14 +49,14 @@ var testAccount = &server.Account{
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
}
|
||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return &jwt.Token{
|
||||
Claims: jwt.MapClaims{
|
||||
@@ -67,14 +69,14 @@ func mockValidateAndParseToken(token string) (*jwt.Token, error) {
|
||||
return nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(token string) error {
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
if token == tokenID {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
if testAccount.Id != claims.AccountId {
|
||||
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *
|
||||
for bypassPath := range bypassPaths {
|
||||
matched, err := path.Match(bypassPath, requestPath)
|
||||
if err != nil {
|
||||
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
|
||||
log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
|
||||
continue
|
||||
}
|
||||
if matched {
|
||||
|
||||
@@ -36,16 +36,16 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
|
||||
// GetAllNameservers returns the list of nameserver groups for the account
|
||||
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id, user.Id)
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -54,15 +54,15 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
|
||||
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, apiNameservers)
|
||||
util.WriteJSONObject(r.Context(), w, apiNameservers)
|
||||
}
|
||||
|
||||
// CreateNameserverGroup handles nameserver group creation request
|
||||
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,33 +75,33 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
|
||||
nsList, err := toServerNSList(req.Nameservers)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(nsGroup)
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
|
||||
nsList, err := toServerNSList(req.Nameservers)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -130,66 +130,66 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveNameServerGroup(account.Id, user.Id, updatedNSGroup)
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(updatedNSGroup)
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// DeleteNameserverGroup handles nameserver group deletion request
|
||||
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID, user.Id)
|
||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
||||
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, user.Id, nsGroupID)
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(nsGroup)
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -61,13 +62,13 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
|
||||
func initNameserversTestData() *NameserversHandler {
|
||||
return &NameserversHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetNameServerGroupFunc: func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
if nsGroupID == existingNSGroupID {
|
||||
return baseExistingNSGroup.Copy(), nil
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||
},
|
||||
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
|
||||
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
|
||||
return &nbdns.NameServerGroup{
|
||||
ID: existingNSGroupID,
|
||||
Name: name,
|
||||
@@ -80,16 +81,16 @@ func initNameserversTestData() *NameserversHandler {
|
||||
SearchDomainsEnabled: searchDomains,
|
||||
}, nil
|
||||
},
|
||||
DeleteNameServerGroupFunc: func(accountID, nsGroupID, _ string) error {
|
||||
DeleteNameServerGroupFunc: func(_ context.Context, accountID, nsGroupID, _ string) error {
|
||||
return nil
|
||||
},
|
||||
SaveNameServerGroupFunc: func(accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
SaveNameServerGroupFunc: func(_ context.Context, accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
if nsGroupToSave.ID == existingNSGroupID {
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingNSAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
},
|
||||
|
||||
@@ -34,22 +34,22 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
|
||||
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID)
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -58,53 +58,53 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
patResponse = append(patResponse, toPATResponse(pat))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, patResponse)
|
||||
util.WriteJSONObject(r.Context(), w, patResponse)
|
||||
}
|
||||
|
||||
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := vars["tokenId"]
|
||||
if len(tokenID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID)
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toPATResponse(pat))
|
||||
util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
|
||||
}
|
||||
|
||||
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -115,44 +115,44 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toPATGeneratedResponse(pat))
|
||||
util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
|
||||
}
|
||||
|
||||
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := vars["tokenId"]
|
||||
if len(tokenID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID)
|
||||
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -63,7 +64,7 @@ var testAccount = &server.Account{
|
||||
func initPATTestData() *PATHandler {
|
||||
return &PATHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@@ -76,10 +77,10 @@ func initPATTestData() *PATHandler {
|
||||
}, nil
|
||||
},
|
||||
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testAccount, testAccount.Users[existingUserID], nil
|
||||
},
|
||||
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@@ -91,7 +92,7 @@ func initPATTestData() *PATHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@@ -103,7 +104,7 @@ func initPATTestData() *PATHandler {
|
||||
}
|
||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||
},
|
||||
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -47,16 +48,16 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
return peerToReturn, nil
|
||||
}
|
||||
|
||||
func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(account.Id, peerID, userID)
|
||||
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
@@ -65,19 +66,19 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(fmt.Errorf("internal error"), w)
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -99,9 +100,9 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update)
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
@@ -110,75 +111,75 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(fmt.Errorf("internal error"), w)
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
err := h.accountManager.DeletePeer(accountID, peerID, userID)
|
||||
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(err, w)
|
||||
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(ctx, w, emptyObject{})
|
||||
}
|
||||
|
||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodDelete:
|
||||
h.deletePeer(account.Id, user.Id, peerID, w)
|
||||
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
|
||||
return
|
||||
case http.MethodPut:
|
||||
h.updatePeer(account, user, peerID, w, r)
|
||||
h.updatePeer(r.Context(), account, user, peerID, w, r)
|
||||
return
|
||||
case http.MethodGet:
|
||||
h.getPeer(account, peerID, user.Id, w)
|
||||
h.getPeer(r.Context(), account, peerID, user.Id, w)
|
||||
return
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllPeers returns a list of all peers associated with a provided account
|
||||
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -188,34 +189,34 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
for _, peer := range peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID)
|
||||
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(fmt.Errorf("internal error"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
h.setApprovalRequiredFlag(respBody, validPeersMap)
|
||||
|
||||
util.WriteJSONObject(w, respBody)
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
}
|
||||
|
||||
func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) {
|
||||
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
|
||||
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
|
||||
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
@@ -29,7 +30,7 @@ const noUpdateChannelTestPeerID = "no-update-channel"
|
||||
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
return &PeersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
UpdatePeerFunc: func(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
var p *nbpeer.Peer
|
||||
for _, peer := range peers {
|
||||
if update.ID == peer.ID {
|
||||
@@ -42,7 +43,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
p.Name = update.Name
|
||||
return p, nil
|
||||
},
|
||||
GetPeerFunc: func(accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
var p *nbpeer.Peer
|
||||
for _, peer := range peers {
|
||||
if peerID == peer.ID {
|
||||
@@ -52,13 +53,13 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return peers, nil
|
||||
},
|
||||
GetDNSDomainFunc: func() string {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
|
||||
@@ -35,15 +35,15 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
||||
// GetAllPolicies list for the account
|
||||
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id)
|
||||
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -51,28 +51,28 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
for _, policy := range accountPolicies {
|
||||
resp := toPolicyResponse(account, policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
}
|
||||
policies = append(policies, resp)
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, policies)
|
||||
util.WriteJSONObject(r.Context(), w, policies)
|
||||
}
|
||||
|
||||
// UpdatePolicy handles update to a policy identified by a given ID
|
||||
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
if policyIdx < 0 {
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -94,9 +94,9 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
// CreatePolicy handles policy creation request
|
||||
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -118,12 +118,12 @@ func (h *Policies) savePolicy(
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Rules) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -137,31 +137,31 @@ func (h *Policies) savePolicy(
|
||||
Enabled: req.Enabled,
|
||||
Description: req.Description,
|
||||
}
|
||||
for _, r := range req.Rules {
|
||||
for _, rule := range req.Rules {
|
||||
pr := server.PolicyRule{
|
||||
ID: policyID, //TODO: when policy can contain multiple rules, need refactor
|
||||
Name: r.Name,
|
||||
Destinations: groupMinimumsToStrings(account, r.Destinations),
|
||||
Sources: groupMinimumsToStrings(account, r.Sources),
|
||||
Bidirectional: r.Bidirectional,
|
||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||
Name: rule.Name,
|
||||
Destinations: groupMinimumsToStrings(account, rule.Destinations),
|
||||
Sources: groupMinimumsToStrings(account, rule.Sources),
|
||||
Bidirectional: rule.Bidirectional,
|
||||
}
|
||||
|
||||
pr.Enabled = r.Enabled
|
||||
if r.Description != nil {
|
||||
pr.Description = *r.Description
|
||||
pr.Enabled = rule.Enabled
|
||||
if rule.Description != nil {
|
||||
pr.Description = *rule.Description
|
||||
}
|
||||
|
||||
switch r.Action {
|
||||
switch rule.Action {
|
||||
case api.PolicyRuleUpdateActionAccept:
|
||||
pr.Action = server.PolicyTrafficActionAccept
|
||||
case api.PolicyRuleUpdateActionDrop:
|
||||
pr.Action = server.PolicyTrafficActionDrop
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Protocol {
|
||||
switch rule.Protocol {
|
||||
case api.PolicyRuleUpdateProtocolAll:
|
||||
pr.Protocol = server.PolicyRuleProtocolALL
|
||||
case api.PolicyRuleUpdateProtocolTcp:
|
||||
@@ -171,14 +171,14 @@ func (h *Policies) savePolicy(
|
||||
case api.PolicyRuleUpdateProtocolIcmp:
|
||||
pr.Protocol = server.PolicyRuleProtocolICMP
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown protocol type: %v", r.Protocol), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Ports != nil && len(*r.Ports) != 0 {
|
||||
for _, v := range *r.Ports {
|
||||
if rule.Ports != nil && len(*rule.Ports) != 0 {
|
||||
for _, v := range *rule.Ports {
|
||||
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
|
||||
return
|
||||
}
|
||||
pr.Ports = append(pr.Ports, v)
|
||||
@@ -189,16 +189,16 @@ func (h *Policies) savePolicy(
|
||||
switch pr.Protocol {
|
||||
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
|
||||
if len(pr.Ports) != 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
|
||||
return
|
||||
}
|
||||
if !pr.Bidirectional {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
return
|
||||
}
|
||||
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
|
||||
if !pr.Bidirectional && len(pr.Ports) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -210,26 +210,26 @@ func (h *Policies) savePolicy(
|
||||
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
|
||||
}
|
||||
|
||||
if err := h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil {
|
||||
util.WriteError(err, w)
|
||||
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(account, &policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, resp)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
// DeletePolicy handles policy deletion request
|
||||
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
aID := account.Id
|
||||
@@ -237,24 +237,24 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.accountManager.DeletePolicy(aID, policyID, user.Id); err != nil {
|
||||
util.WriteError(err, w)
|
||||
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetPolicy handles a group Get request identified by ID
|
||||
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -263,25 +263,25 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := h.accountManager.GetPolicy(account.Id, policyID, user.Id)
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(account, policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, resp)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -30,21 +31,21 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
}
|
||||
return &Policies{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) {
|
||||
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
|
||||
policy, ok := testPolicies[policyID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "policy not found")
|
||||
}
|
||||
return policy, nil
|
||||
},
|
||||
SavePolicyFunc: func(_, _ string, policy *server.Policy) error {
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
|
||||
if !strings.HasPrefix(policy.ID, "id-") {
|
||||
policy.ID = "id-was-set"
|
||||
policy.Rules[0].ID = "id-was-set"
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
|
||||
@@ -37,15 +37,15 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
|
||||
// GetAllPostureChecks list for the account
|
||||
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPostureChecks, err := p.accountManager.ListPostureChecks(account.Id, user.Id)
|
||||
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -54,22 +54,22 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
|
||||
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, postureChecks)
|
||||
util.WriteJSONObject(r.Context(), w, postureChecks)
|
||||
}
|
||||
|
||||
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
||||
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
||||
}
|
||||
}
|
||||
if postureChecksIdx < 0 {
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -91,9 +91,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
||||
// CreatePostureCheck handles posture check creation request
|
||||
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -103,50 +103,50 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http
|
||||
// GetPostureCheck handles a posture check Get request identified by ID
|
||||
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(account.Id, postureChecksID, user.Id)
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
|
||||
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
|
||||
}
|
||||
|
||||
// DeletePostureCheck handles posture check deletion request
|
||||
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = p.accountManager.DeletePostureChecks(account.Id, postureChecksID, user.Id); err != nil {
|
||||
util.WriteError(err, w)
|
||||
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// savePostureChecks handles posture checks create and update
|
||||
@@ -169,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
|
||||
|
||||
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
|
||||
if p.geolocationManager == nil {
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
|
||||
return
|
||||
}
|
||||
@@ -177,14 +177,14 @@ func (p *PostureChecksHandler) savePostureChecks(
|
||||
|
||||
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil {
|
||||
util.WriteError(err, w)
|
||||
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
|
||||
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -33,14 +34,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
|
||||
return &PostureChecksHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetPostureChecksFunc: func(accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
p, ok := testPostureChecks[postureChecksID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "posture checks not found")
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
postureChecks.ID = "postureCheck"
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
@@ -50,7 +51,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
|
||||
return nil
|
||||
},
|
||||
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
|
||||
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
||||
_, ok := testPostureChecks[postureChecksID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "posture checks not found")
|
||||
@@ -59,14 +60,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
|
||||
return nil
|
||||
},
|
||||
ListPostureChecksFunc: func(accountID, userID string) ([]*posture.Checks, error) {
|
||||
ListPostureChecksFunc: func(_ context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
accountPostureChecks := make([]*posture.Checks, len(testPostureChecks))
|
||||
for _, p := range testPostureChecks {
|
||||
accountPostureChecks = append(accountPostureChecks, p)
|
||||
}
|
||||
return accountPostureChecks, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
|
||||
@@ -43,36 +43,36 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
|
||||
// GetAllRoutes returns the list of routes for the account
|
||||
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
apiRoutes := make([]*api.Route, 0)
|
||||
for _, r := range routes {
|
||||
route, err := toRouteResponse(r)
|
||||
for _, route := range routes {
|
||||
route, err := toRouteResponse(route)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
apiRoutes = append(apiRoutes, route)
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, apiRoutes)
|
||||
util.WriteJSONObject(r.Context(), w, apiRoutes)
|
||||
}
|
||||
|
||||
// CreateRoute handles route creation request
|
||||
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := h.validateRoute(req); err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Domains != nil {
|
||||
d, err := validateDomains(*req.Domains)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
return
|
||||
}
|
||||
domains = d
|
||||
@@ -102,7 +102,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
} else if req.Network != nil {
|
||||
networkType, newPrefix, err = route.ParseNetwork(*req.Network)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -120,24 +120,24 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(peerId); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := toRouteResponse(newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, routes)
|
||||
util.WriteJSONObject(r.Context(), w, routes)
|
||||
}
|
||||
|
||||
func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
|
||||
@@ -168,22 +168,22 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
||||
// UpdateRoute handles update to a route identified by a given ID
|
||||
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
routeID := vars["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
|
||||
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := h.validateRoute(req); err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
// do not allow non Linux peers
|
||||
if peer := account.GetPeer(peerID); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -226,7 +226,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Domains != nil {
|
||||
d, err := validateDomains(*req.Domains)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
return
|
||||
}
|
||||
newRoute.Domains = d
|
||||
@@ -234,7 +234,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
} else if req.Network != nil {
|
||||
newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -247,73 +247,73 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
newRoute.PeerGroups = *req.PeerGroups
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute)
|
||||
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := toRouteResponse(newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, routes)
|
||||
util.WriteJSONObject(r.Context(), w, routes)
|
||||
}
|
||||
|
||||
// DeleteRoute handles route deletion request
|
||||
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id)
|
||||
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetRoute handles a route Get request identified by ID
|
||||
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := toRouteResponse(foundRoute)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, routes)
|
||||
util.WriteJSONObject(r.Context(), w, routes)
|
||||
}
|
||||
|
||||
func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -89,7 +90,7 @@ var testingAccount = &server.Account{
|
||||
func initRoutesTestData() *RoutesHandler {
|
||||
return &RoutesHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) {
|
||||
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
|
||||
if routeID == existingRouteID {
|
||||
return baseExistingRoute, nil
|
||||
}
|
||||
@@ -104,7 +105,7 @@ func initRoutesTestData() *RoutesHandler {
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
},
|
||||
CreateRouteFunc: func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
|
||||
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
|
||||
if peerID == notFoundPeerID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
@@ -126,19 +127,19 @@ func initRoutesTestData() *RoutesHandler {
|
||||
KeepRoute: keepRoute,
|
||||
}, nil
|
||||
},
|
||||
SaveRouteFunc: func(_, _ string, r *route.Route) error {
|
||||
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
|
||||
if r.Peer == notFoundPeerID {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error {
|
||||
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
|
||||
if routeID != existingRouteID {
|
||||
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -34,9 +35,9 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
|
||||
// CreateSetupKey is a POST requests that creates a new SetupKey
|
||||
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -48,13 +49,13 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
|
||||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -63,7 +64,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
day := time.Hour * 24
|
||||
year := day * 365
|
||||
if expiresIn < day || expiresIn > year {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,54 +76,54 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
if req.Ephemeral != nil {
|
||||
ephemeral = *req.Ephemeral
|
||||
}
|
||||
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeSuccess(w, setupKey)
|
||||
writeSuccess(r.Context(), w, setupKey)
|
||||
}
|
||||
|
||||
// GetSetupKey is a GET request to get a SetupKey by ID
|
||||
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeSuccess(w, key)
|
||||
writeSuccess(r.Context(), w, key)
|
||||
}
|
||||
|
||||
// UpdateSetupKey is a PUT request to update server.SetupKey
|
||||
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -134,12 +135,12 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -149,26 +150,26 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
newKey.Name = req.Name
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey, user.Id)
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
writeSuccess(w, newKey)
|
||||
writeSuccess(r.Context(), w, newKey)
|
||||
}
|
||||
|
||||
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
||||
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -177,15 +178,15 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
|
||||
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, apiSetupKeys)
|
||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||
}
|
||||
|
||||
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
||||
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
err := json.NewEncoder(w).Encode(toResponseBody(key))
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -33,7 +34,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
) *SetupKeysHandler {
|
||||
return &SetupKeysHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: testAccountID,
|
||||
Domain: "hotmail.com",
|
||||
@@ -49,7 +50,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||
_ int, _ string, ephemeral bool,
|
||||
) (*server.SetupKey, error) {
|
||||
if keyName == newKey.Name || typ != newKey.Type {
|
||||
@@ -59,7 +60,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
}
|
||||
return nil, fmt.Errorf("failed creating setup key")
|
||||
},
|
||||
GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
|
||||
GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
|
||||
switch keyID {
|
||||
case defaultKey.Id:
|
||||
return defaultKey, nil
|
||||
@@ -70,14 +71,14 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
}
|
||||
},
|
||||
|
||||
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
|
||||
SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
|
||||
if key.Id == updatedSetupKey.Id {
|
||||
return updatedSetupKey, nil
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
|
||||
},
|
||||
|
||||
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
|
||||
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
|
||||
return []*server.SetupKey{defaultKey}, nil
|
||||
},
|
||||
},
|
||||
|
||||
@@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
existingUser, ok := account.Users[userID]
|
||||
if !ok {
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -74,11 +74,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userRole := server.StrRoleToUserRole(req.Role)
|
||||
if userRole == server.UserRoleUnknown {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(account.Id, user.Id, &server.User{
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
|
||||
Id: userID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
@@ -88,10 +88,10 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// DeleteUser is a DELETE request to delete a user
|
||||
@@ -102,26 +102,26 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID)
|
||||
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
@@ -132,9 +132,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: req.Role,
|
||||
@@ -169,10 +169,10 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
Issued: server.UserIssuedAPI,
|
||||
})
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// GetAllUsers returns a list of users of the account this user belongs to.
|
||||
@@ -184,42 +184,42 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
serviceUser := r.URL.Query().Get("service_user")
|
||||
|
||||
users := make([]*api.User, 0)
|
||||
for _, r := range data {
|
||||
if r.NonDeletable {
|
||||
for _, d := range data {
|
||||
if d.NonDeletable {
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
continue
|
||||
}
|
||||
|
||||
includeServiceUser, err := strconv.ParseBool(serviceUser)
|
||||
log.Debugf("Should include service user: %v", includeServiceUser)
|
||||
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
return
|
||||
}
|
||||
if includeServiceUser == r.IsServiceUser {
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, users)
|
||||
util.WriteJSONObject(r.Context(), w, users)
|
||||
}
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts,
|
||||
@@ -231,26 +231,26 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID)
|
||||
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -63,10 +64,10 @@ var usersTestAccount = &server.Account{
|
||||
func initUsersTestData() *UsersHandler {
|
||||
return &UsersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
users := make([]*server.UserInfo, 0)
|
||||
for _, v := range usersTestAccount.Users {
|
||||
users = append(users, &server.UserInfo{
|
||||
@@ -81,13 +82,13 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return users, nil
|
||||
},
|
||||
CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
|
||||
CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
|
||||
if userID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
return key, nil
|
||||
},
|
||||
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
DeleteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if targetUserID == notFoundUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||
}
|
||||
@@ -96,7 +97,7 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||
SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||
if update.Id == notFoundUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
|
||||
}
|
||||
@@ -111,7 +112,7 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return info, nil
|
||||
},
|
||||
InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
InviteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if initiatorUserID != existingUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -19,12 +20,12 @@ type ErrorResponse struct {
|
||||
}
|
||||
|
||||
// WriteJSONObject simply writes object to the HTTP response in JSON format
|
||||
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||
func WriteJSONObject(ctx context.Context, w http.ResponseWriter, obj interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
err := json.NewEncoder(w).Encode(obj)
|
||||
if err != nil {
|
||||
WriteError(err, w)
|
||||
WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -76,8 +77,8 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
||||
|
||||
// WriteError converts an error to an JSON error response.
|
||||
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
|
||||
func WriteError(err error, w http.ResponseWriter) {
|
||||
log.Errorf("got a handler error: %s", err.Error())
|
||||
func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
log.WithContext(ctx).Errorf("got a handler error: %s", err.Error())
|
||||
errStatus, ok := status.FromError(err)
|
||||
httpStatus := http.StatusInternalServerError
|
||||
msg := "internal server error"
|
||||
@@ -106,7 +107,7 @@ func WriteError(err error, w http.ResponseWriter) {
|
||||
msg = strings.ToLower(err.Error())
|
||||
} else {
|
||||
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
|
||||
log.Error(unhandledMSG)
|
||||
log.WithContext(ctx).Error(unhandledMSG)
|
||||
}
|
||||
|
||||
WriteErrorResponse(msg, httpStatus, w)
|
||||
|
||||
@@ -183,7 +183,7 @@ func (c *Auth0Credentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token
|
||||
func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
|
||||
func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
var res *http.Response
|
||||
reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
|
||||
|
||||
@@ -200,7 +200,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debug("requesting new jwt token for idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
|
||||
|
||||
res, err = c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -247,7 +247,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the Auth0 Management API
|
||||
func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
|
||||
func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
@@ -260,14 +260,14 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
res, err := c.requestJWTToken()
|
||||
res, err := c.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return c.jwtToken, err
|
||||
}
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing get jwt token response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -301,8 +301,8 @@ func requestByUserIDURL(authIssuer, userID string) string {
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile. Calls Auth0 API.
|
||||
func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -353,7 +353,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
|
||||
log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
|
||||
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
@@ -365,7 +365,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
if len(batch) == 0 || len(batch) < resultsPerPage {
|
||||
log.Debugf("finished loading users for accountID %s", accountID)
|
||||
log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
|
||||
return list, nil
|
||||
}
|
||||
}
|
||||
@@ -374,8 +374,8 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from auth0 via ID
|
||||
func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -414,7 +414,7 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing update user app metadata response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -426,9 +426,9 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map
|
||||
func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
|
||||
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -449,7 +449,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debugf("updating IdP metadata for user %s", userID)
|
||||
log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -466,7 +466,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing update user app metadata response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -530,9 +530,9 @@ func buildUserExportRequest() (string, error) {
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createRequest(
|
||||
method string, endpoint string, body io.Reader,
|
||||
ctx context.Context, method string, endpoint string, body io.Reader,
|
||||
) (*http.Request, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -548,8 +548,8 @@ func (am *Auth0Manager) createRequest(
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
|
||||
req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr))
|
||||
func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
|
||||
req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -560,20 +560,20 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
payloadString, err := buildUserExportRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString)
|
||||
exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jobResp, err := am.httpClient.Do(exportJobReq)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't get job response %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@@ -583,7 +583,7 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
defer func() {
|
||||
err = jobResp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing update user app metadata response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if jobResp.StatusCode != 200 {
|
||||
@@ -597,13 +597,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
body, err := io.ReadAll(jobResp.Body)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't read export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &exportJobResp)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -614,16 +614,16 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
}
|
||||
|
||||
log.Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
|
||||
done, downloadLink, err := am.checkExportJobStatus(exportJobResp.ID)
|
||||
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
|
||||
if err != nil {
|
||||
log.Debugf("Failed at getting status checks from exportJob; %v", err)
|
||||
log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if done {
|
||||
return am.downloadProfileExport(downloadLink)
|
||||
return am.downloadProfileExport(ctx, downloadLink)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed extracting user profiles from auth0")
|
||||
@@ -632,13 +632,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
|
||||
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
|
||||
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
|
||||
func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
|
||||
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
|
||||
body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -651,7 +651,7 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
|
||||
err = am.helper.Unmarshal(body, &userResp)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -659,13 +659,13 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Auth0 Idp and sends an invite
|
||||
func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
|
||||
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := am.createPostRequest("/api/v2/users", payloadString)
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -676,7 +676,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't get job response %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@@ -686,7 +686,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing create user response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
@@ -700,13 +700,13 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't read export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &createResp)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -714,14 +714,14 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
|
||||
return nil, fmt.Errorf("couldn't create user: response %v", resp)
|
||||
}
|
||||
|
||||
log.Debugf("created user %s in account %s", createResp.ID, accountID)
|
||||
log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
|
||||
|
||||
return &createResp, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
userVerificationReq := userVerificationJobRequest{
|
||||
UserID: userID,
|
||||
}
|
||||
@@ -731,14 +731,14 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload))
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't get job response %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@@ -748,7 +748,7 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing invite user response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
@@ -762,15 +762,15 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
}
|
||||
|
||||
// DeleteUser from Auth0
|
||||
func (am *Auth0Manager) DeleteUser(userID string) error {
|
||||
req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
||||
func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
|
||||
req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("execute delete request: %v", err)
|
||||
log.WithContext(ctx).Debugf("execute delete request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@@ -780,7 +780,7 @@ func (am *Auth0Manager) DeleteUser(userID string) error {
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("close delete request body: %v", err)
|
||||
log.WithContext(ctx).Errorf("close delete request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 204 {
|
||||
@@ -795,20 +795,20 @@ func (am *Auth0Manager) DeleteUser(userID string) error {
|
||||
|
||||
// GetAllConnections returns detailed list of all connections filtered by given params.
|
||||
// Note this method is not part of the IDP Manager interface as this is Auth0 specific.
|
||||
func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, error) {
|
||||
func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
|
||||
var connections []Connection
|
||||
|
||||
q := make(url.Values)
|
||||
q.Set("strategy", strings.Join(strategy, ","))
|
||||
|
||||
req, err := am.createRequest(http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
|
||||
req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
|
||||
if err != nil {
|
||||
return connections, err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("execute get connections request: %v", err)
|
||||
log.WithContext(ctx).Debugf("execute get connections request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@@ -818,7 +818,7 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("close get connections request body: %v", err)
|
||||
log.WithContext(ctx).Errorf("close get connections request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 200 {
|
||||
@@ -830,13 +830,13 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't read get connections response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &connections)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't unmarshal get connection response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
@@ -845,23 +845,23 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
|
||||
|
||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||
// If the status is "completed", then return the downloadLink
|
||||
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
|
||||
defer cancel()
|
||||
retry := time.NewTicker(10 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Export job status stopped...\n")
|
||||
log.WithContext(ctx).Debugf("Export job status stopped...\n")
|
||||
return false, "", ctx.Err()
|
||||
case <-retry.C:
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
|
||||
body, err := doGetReq(am.httpClient, statusURL, jwtToken.AccessToken)
|
||||
body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
@@ -872,7 +872,7 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error)
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
log.Debugf("current export job status is %v", status.Status)
|
||||
log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
|
||||
|
||||
if status.Status != "completed" {
|
||||
continue
|
||||
@@ -884,8 +884,8 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error)
|
||||
}
|
||||
|
||||
// downloadProfileExport downloads user profiles from auth0 batch job
|
||||
func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*UserData, error) {
|
||||
body, err := doGetReq(am.httpClient, location, "")
|
||||
func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
|
||||
body, err := doGetReq(ctx, am.httpClient, location, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -927,7 +927,7 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
|
||||
}
|
||||
|
||||
// Boilerplate implementation for Get Requests.
|
||||
func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
||||
func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -945,7 +945,7 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing body for url %s: %v", url, err)
|
||||
log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -60,7 +61,7 @@ type mockAuth0Credentials struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockAuth0Credentials) Authenticate() (JWTToken, error) {
|
||||
func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
@@ -126,7 +127,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
res, err := creds.requestJWTToken()
|
||||
res, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@@ -295,7 +296,7 @@ func TestAuth0_Authenticate(t *testing.T) {
|
||||
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@@ -417,7 +418,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||
err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")
|
||||
|
||||
@@ -116,7 +116,7 @@ func (ac *AuthentikCredentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) {
|
||||
func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("username", ac.clientConfig.Username)
|
||||
@@ -131,7 +131,7 @@ func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) {
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for authentik idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -183,7 +183,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the authentik management API.
|
||||
func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
|
||||
func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
@@ -197,7 +197,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken()
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
@@ -214,13 +214,13 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from authentik via ID.
|
||||
func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
ctx, err := am.authenticationContext()
|
||||
func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -254,8 +254,8 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers()
|
||||
func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -274,8 +274,8 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers()
|
||||
func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -291,12 +291,12 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in a Authentik account.
|
||||
func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
|
||||
func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
page := int32(1)
|
||||
for {
|
||||
ctx, err := am.authenticationContext()
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -329,14 +329,14 @@ func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in authentik Idp and sends an invitation.
|
||||
func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
ctx, err := am.authenticationContext()
|
||||
func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -368,13 +368,13 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AuthentikManager) InviteUserByID(_ string) error {
|
||||
func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Authentik
|
||||
func (am *AuthentikManager) DeleteUser(userID string) error {
|
||||
ctx, err := am.authenticationContext()
|
||||
func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -404,8 +404,8 @@ func (am *AuthentikManager) DeleteUser(userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
@@ -138,7 +139,7 @@ func TestAuthentikRequestJWTToken(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken()
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@@ -304,7 +305,7 @@ func TestAuthentikAuthenticate(t *testing.T) {
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -110,7 +111,7 @@ func (ac *AzureCredentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
|
||||
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
||||
@@ -132,7 +133,7 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for azure idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -184,7 +185,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the azure Management API.
|
||||
func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
||||
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
@@ -198,7 +199,7 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken()
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
@@ -215,16 +216,16 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in azure AD Idp.
|
||||
func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get("users/"+userID, q)
|
||||
body, err := am.get(ctx, "users/"+userID, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -247,11 +248,11 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata)
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get("users/"+email, q)
|
||||
body, err := am.get(ctx, "users/"+email, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -273,8 +274,8 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers()
|
||||
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -293,8 +294,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers()
|
||||
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -310,19 +311,19 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID.
|
||||
func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AzureManager) InviteUserByID(_ string) error {
|
||||
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Azure.
|
||||
func (am *AzureManager) DeleteUser(userID string) error {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -335,7 +336,7 @@ func (am *AzureManager) DeleteUser(userID string) error {
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debugf("delete idp user %s", userID)
|
||||
log.WithContext(ctx).Debugf("delete idp user %s", userID)
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -358,7 +359,7 @@ func (am *AzureManager) DeleteUser(userID string) error {
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in an Azure AD account.
|
||||
func (am *AzureManager) getAllUsers() ([]*UserData, error) {
|
||||
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
q := url.Values{}
|
||||
@@ -366,7 +367,7 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) {
|
||||
q.Add("$top", "500")
|
||||
|
||||
for nextLink := "users"; nextLink != ""; {
|
||||
body, err := am.get(nextLink, q)
|
||||
body, err := am.get(ctx, nextLink, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -391,8 +392,8 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -101,7 +102,7 @@ func TestAzureAuthenticate(t *testing.T) {
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
|
||||
@@ -39,12 +39,12 @@ type GoogleWorkspaceCredentials struct {
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
func (gc *GoogleWorkspaceCredentials) Authenticate() (JWTToken, error) {
|
||||
func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
|
||||
func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
|
||||
func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -66,7 +66,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
|
||||
}
|
||||
|
||||
// Create a new Admin SDK Directory service client
|
||||
adminCredentials, err := getGoogleCredentials(config.ServiceAccountKey)
|
||||
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -90,12 +90,12 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from Google Workspace via ID.
|
||||
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, err := gm.usersService.Get(userID).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -112,7 +112,7 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -132,7 +132,7 @@ func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, err
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -177,13 +177,13 @@ func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Google Workspace and sends an invitation.
|
||||
func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, err := gm.usersService.Get(email).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -201,12 +201,12 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
|
||||
func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from GoogleWorkspace.
|
||||
func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
|
||||
func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
|
||||
if err := gm.usersService.Delete(userID).Do(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -222,8 +222,8 @@ func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
|
||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||
// If that fails, it falls back to using the default Google credentials path.
|
||||
// It returns the retrieved credentials or an error if unsuccessful.
|
||||
func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) {
|
||||
log.Debug("retrieving google credentials from the base64 encoded service account key")
|
||||
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
|
||||
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
|
||||
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode service account key: %w", err)
|
||||
@@ -239,8 +239,8 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
|
||||
log.Debug("falling back to default google credentials location")
|
||||
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
|
||||
log.WithContext(ctx).Debug("falling back to default google credentials location")
|
||||
|
||||
creds, err = google.FindDefaultCredentials(
|
||||
context.Background(),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -16,14 +17,14 @@ const (
|
||||
|
||||
// Manager idp manager interface
|
||||
type Manager interface {
|
||||
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccount(accountId string) ([]*UserData, error)
|
||||
GetAllAccounts() (map[string][]*UserData, error)
|
||||
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmail(email string) ([]*UserData, error)
|
||||
InviteUserByID(userID string) error
|
||||
DeleteUser(userID string) error
|
||||
UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
||||
InviteUserByID(ctx context.Context, userID string) error
|
||||
DeleteUser(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// ClientConfig defines common client configuration for all IdP manager
|
||||
@@ -51,7 +52,7 @@ type Config struct {
|
||||
|
||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||
type ManagerCredentials interface {
|
||||
Authenticate() (JWTToken, error)
|
||||
Authenticate(ctx context.Context) (JWTToken, error)
|
||||
}
|
||||
|
||||
// ManagerHTTPClient http client interface for API calls
|
||||
@@ -91,7 +92,7 @@ type JWTToken struct {
|
||||
}
|
||||
|
||||
// NewManager returns a new idp manager based on the configuration that it receives
|
||||
func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||
if config.ClientConfig != nil {
|
||||
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
|
||||
}
|
||||
@@ -175,7 +176,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
||||
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
||||
CustomerID: config.ExtraConfig["CustomerId"],
|
||||
}
|
||||
return NewGoogleWorkspaceManager(googleClientConfig, appMetrics)
|
||||
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
|
||||
case "jumpcloud":
|
||||
jumpcloudConfig := JumpCloudClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
|
||||
@@ -74,7 +74,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the JumpCloud user API.
|
||||
func (jc *JumpCloudCredentials) Authenticate() (JWTToken, error) {
|
||||
func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
@@ -85,12 +85,12 @@ func (jm *JumpCloudManager) authenticationContext() context.Context {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
||||
func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
@@ -116,7 +116,7 @@ func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetada
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
@@ -148,7 +148,7 @@ func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
@@ -177,13 +177,13 @@ func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
|
||||
func (jm *JumpCloudManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
searchFilter := map[string]interface{}{
|
||||
"searchFilter": map[string]interface{}{
|
||||
"filter": []string{email},
|
||||
@@ -219,12 +219,12 @@ func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (jm *JumpCloudManager) InviteUserByID(_ string) error {
|
||||
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from jumpCloud directory
|
||||
func (jm *JumpCloudManager) DeleteUser(userID string) error {
|
||||
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
||||
authCtx := jm.authenticationContext()
|
||||
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -109,7 +110,7 @@ func (kc *KeycloakCredentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
|
||||
func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kc.clientConfig.ClientID)
|
||||
data.Set("client_secret", kc.clientConfig.ClientSecret)
|
||||
@@ -122,7 +123,7 @@ func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for keycloak idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
|
||||
|
||||
resp, err := kc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -174,7 +175,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the keycloak Management API.
|
||||
func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
kc.mux.Lock()
|
||||
defer kc.mux.Unlock()
|
||||
|
||||
@@ -188,7 +189,7 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := kc.requestJWTToken()
|
||||
resp, err := kc.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
@@ -205,18 +206,18 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in keycloak Idp and sends an invite.
|
||||
func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("email", email)
|
||||
q.Add("exact", "true")
|
||||
|
||||
body, err := km.get("users", q)
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -240,8 +241,8 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) {
|
||||
body, err := km.get("users/"+userID, nil)
|
||||
func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
|
||||
body, err := km.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -260,8 +261,8 @@ func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserD
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given account profile.
|
||||
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles()
|
||||
func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -283,8 +284,8 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles()
|
||||
func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -303,19 +304,19 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (km *KeycloakManager) InviteUserByID(_ string) error {
|
||||
func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Keycloak by user ID.
|
||||
func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -353,8 +354,8 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
|
||||
totalUsers, err := km.totalUsersCount()
|
||||
func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
|
||||
totalUsers, err := km.totalUsersCount(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -362,7 +363,7 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
|
||||
q := url.Values{}
|
||||
q.Add("max", fmt.Sprint(*totalUsers))
|
||||
|
||||
body, err := km.get("users", q)
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -377,8 +378,8 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -414,8 +415,8 @@ func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
|
||||
// totalUsersCount returns the total count of all user created.
|
||||
// Used when fetching all registered accounts with pagination.
|
||||
func (km *KeycloakManager) totalUsersCount() (*int, error) {
|
||||
body, err := km.get("users/count", nil)
|
||||
func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
|
||||
body, err := km.get(ctx, "users/count", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
@@ -128,7 +129,7 @@ func TestKeycloakRequestJWTToken(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken()
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@@ -294,7 +295,7 @@ func TestKeycloakAuthenticate(t *testing.T) {
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
|
||||
@@ -1,77 +1,79 @@
|
||||
package idp
|
||||
|
||||
import "context"
|
||||
|
||||
// MockIDP is a mock implementation of the IDP interface
|
||||
type MockIDP struct {
|
||||
UpdateUserAppMetadataFunc func(userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByIDFunc func(userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccountFunc func(accountId string) ([]*UserData, error)
|
||||
GetAllAccountsFunc func() (map[string][]*UserData, error)
|
||||
CreateUserFunc func(email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmailFunc func(email string) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(userID string) error
|
||||
DeleteUserFunc func(userID string) error
|
||||
UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
||||
DeleteUserFunc func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
|
||||
func (m *MockIDP) UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error {
|
||||
func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
|
||||
if m.UpdateUserAppMetadataFunc != nil {
|
||||
return m.UpdateUserAppMetadataFunc(userId, appMetadata)
|
||||
return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
||||
func (m *MockIDP) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
if m.GetUserDataByIDFunc != nil {
|
||||
return m.GetUserDataByIDFunc(userId, appMetadata)
|
||||
return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAccount is a mock implementation of the IDP interface GetAccount method
|
||||
func (m *MockIDP) GetAccount(accountId string) ([]*UserData, error) {
|
||||
func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
||||
if m.GetAccountFunc != nil {
|
||||
return m.GetAccountFunc(accountId)
|
||||
return m.GetAccountFunc(ctx, accountId)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
|
||||
func (m *MockIDP) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
if m.GetAllAccountsFunc != nil {
|
||||
return m.GetAllAccountsFunc()
|
||||
return m.GetAllAccountsFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// CreateUser is a mock implementation of the IDP interface CreateUser method
|
||||
func (m *MockIDP) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
if m.CreateUserFunc != nil {
|
||||
return m.CreateUserFunc(email, name, accountID, invitedByEmail)
|
||||
return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail is a mock implementation of the IDP interface GetUserByEmail method
|
||||
func (m *MockIDP) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
if m.GetUserByEmailFunc != nil {
|
||||
return m.GetUserByEmailFunc(email)
|
||||
return m.GetUserByEmailFunc(ctx, email)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
|
||||
func (m *MockIDP) InviteUserByID(userID string) error {
|
||||
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
|
||||
if m.InviteUserByIDFunc != nil {
|
||||
return m.InviteUserByIDFunc(userID)
|
||||
return m.InviteUserByIDFunc(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser is a mock implementation of the IDP interface DeleteUser method
|
||||
func (m *MockIDP) DeleteUser(userID string) error {
|
||||
func (m *MockIDP) DeleteUser(ctx context.Context, userID string) error {
|
||||
if m.DeleteUserFunc != nil {
|
||||
return m.DeleteUserFunc(userID)
|
||||
return m.DeleteUserFunc(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -94,17 +94,17 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the okta user API.
|
||||
func (oc *OktaCredentials) Authenticate() (JWTToken, error) {
|
||||
func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in okta Idp and sends an invitation.
|
||||
func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -132,7 +132,7 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -160,7 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -180,7 +180,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -242,18 +242,18 @@ func (om *OktaManager) getAllUsers() ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (om *OktaManager) InviteUserByID(_ string) error {
|
||||
func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Okta
|
||||
func (om *OktaManager) DeleteUser(userID string) error {
|
||||
func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
|
||||
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -149,7 +150,7 @@ func (zc *ZitadelCredentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
|
||||
func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", zc.clientConfig.ClientID)
|
||||
data.Set("client_secret", zc.clientConfig.ClientSecret)
|
||||
@@ -163,7 +164,7 @@ func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for zitadel idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for zitadel idp manager")
|
||||
|
||||
resp, err := zc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -215,7 +216,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the Zitadel Management API.
|
||||
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
zc.mux.Lock()
|
||||
defer zc.mux.Unlock()
|
||||
|
||||
@@ -229,7 +230,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
return zc.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := zc.requestJWTToken()
|
||||
resp, err := zc.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return zc.jwtToken, err
|
||||
}
|
||||
@@ -246,7 +247,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
|
||||
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
firstLast := strings.SplitN(name, " ", 2)
|
||||
|
||||
var addUser = map[string]any{
|
||||
@@ -269,7 +270,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := zm.post("users/human/_import", string(payload))
|
||||
body, err := zm.post(ctx, "users/human/_import", string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -300,7 +301,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
searchByEmail := zitadelAttributes{
|
||||
"queries": {
|
||||
{
|
||||
@@ -316,7 +317,7 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := zm.post("users/_search", string(payload))
|
||||
body, err := zm.post(ctx, "users/_search", string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -340,8 +341,8 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from zitadel via ID.
|
||||
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := zm.get("users/"+userID, nil)
|
||||
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := zm.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -363,8 +364,8 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
body, err := zm.post("users/_search", "")
|
||||
func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -392,8 +393,8 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
body, err := zm.post("users/_search", "")
|
||||
func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -419,7 +420,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
// Metadata values are base64 encoded.
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -429,7 +430,7 @@ type inviteUserRequest struct {
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (zm *ZitadelManager) InviteUserByID(userID string) error {
|
||||
func (zm *ZitadelManager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
inviteUser := inviteUserRequest{
|
||||
Email: userID,
|
||||
}
|
||||
@@ -440,14 +441,14 @@ func (zm *ZitadelManager) InviteUserByID(userID string) error {
|
||||
}
|
||||
|
||||
// don't care about the body in the response
|
||||
_, err = zm.post(fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
|
||||
_, err = zm.post(ctx, fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUser from Zitadel
|
||||
func (zm *ZitadelManager) DeleteUser(userID string) error {
|
||||
func (zm *ZitadelManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
resource := fmt.Sprintf("users/%s", userID)
|
||||
if err := zm.delete(resource); err != nil {
|
||||
if err := zm.delete(ctx, resource); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -459,8 +460,8 @@ func (zm *ZitadelManager) DeleteUser(userID string) error {
|
||||
}
|
||||
|
||||
// post perform Post requests.
|
||||
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
func (zm *ZitadelManager) post(ctx context.Context, resource string, body string) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -495,8 +496,8 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
||||
}
|
||||
|
||||
// delete perform Delete requests.
|
||||
func (zm *ZitadelManager) delete(resource string) error {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
func (zm *ZitadelManager) delete(ctx context.Context, resource string) error {
|
||||
jwtToken, err := zm.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -531,8 +532,8 @@ func (zm *ZitadelManager) delete(resource string) error {
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
@@ -108,7 +109,7 @@ func TestZitadelRequestJWTToken(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken()
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@@ -274,7 +275,7 @@ func TestZitadelAuthenticate(t *testing.T) {
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/google/martian/v3/log"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
)
|
||||
@@ -19,22 +20,22 @@ import (
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if any occurred during the process, otherwise returns nil
|
||||
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
|
||||
ok, err := am.GroupValidation(accountID, groups)
|
||||
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
|
||||
ok, err := am.GroupValidation(ctx, accountID, groups)
|
||||
if err != nil {
|
||||
log.Debugf("error validating groups: %s", err.Error())
|
||||
log.WithContext(ctx).Debugf("error validating groups: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
log.Debugf("invalid groups")
|
||||
log.WithContext(ctx).Debugf("invalid groups")
|
||||
return errors.New("invalid groups")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
a, err := am.Store.GetAccountByUser(userID)
|
||||
a, err := am.Store.GetAccountByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -48,14 +49,14 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
|
||||
a.Settings.Extra = extra
|
||||
}
|
||||
extra.IntegratedValidatorGroups = groups
|
||||
return am.Store.SaveAccount(a)
|
||||
return am.Store.SaveAccount(ctx, a)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) {
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
|
||||
if len(groups) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
accountsGroups, err := am.ListGroups(accountId)
|
||||
accountsGroups, err := am.ListGroups(ctx, accountId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package integrated_validator
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -8,12 +10,12 @@ import (
|
||||
|
||||
// IntegratedValidator interface exists to avoid the circle dependencies
|
||||
type IntegratedValidator interface {
|
||||
ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||
ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
|
||||
PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
|
||||
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
|
||||
PeerDeleted(accountID, peerID string) error
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string) error
|
||||
SetPeerInvalidationListener(fn func(accountID string))
|
||||
Stop()
|
||||
Stop(ctx context.Context)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user