Compare commits

..

19 Commits

Author SHA1 Message Date
braginini
f07671cf49 Fix RemoveConnByUfrag 2022-09-08 19:39:05 +02:00
braginini
5a504ee6be Fix some iface issues 2022-09-07 21:59:01 +02:00
braginini
660b2542d2 Remove unused code 2022-09-07 21:40:45 +02:00
braginini
d0ad53b247 Remove unnecessary endpoint map 2022-09-07 19:40:34 +02:00
braginini
2cffe6526a Add more logging 2022-09-07 19:02:17 +02:00
braginini
dded91235e Refactor UDP mux to handle STUN only messages 2022-09-07 18:49:15 +02:00
braginini
314f34f916 Single Mux 2022-09-07 18:40:42 +02:00
braginini
eaf985624d Single Mux 2022-09-07 18:39:58 +02:00
braginini
48b7c6ec3c Fix TURN issue 2022-09-07 11:17:54 +02:00
braginini
acf271bf25 Merge remote-tracking branch 'origin/main' into feature/interface-bind 2022-09-07 11:09:45 +02:00
braginini
f49c299d77 Check for stun packet with a fixed size 2022-09-06 21:07:21 +02:00
braginini
73b5f8d63b Proper endpoint log 2022-09-06 20:59:19 +02:00
braginini
6653894691 Remove unused code 2022-09-06 20:55:51 +02:00
braginini
a7facc2d72 Split UDPMux and UniversalUDPMux 2022-09-06 20:54:40 +02:00
braginini
0721b87c56 Split UDPMux and UniversalUDPMux 2022-09-06 20:44:49 +02:00
braginini
2829cce644 Implement ICEBind 2022-09-06 20:06:51 +02:00
braginini
9350c5f8d8 bind 2022-09-05 15:56:36 +02:00
braginini
2ae4c204af Working single channel bind 2022-09-05 02:03:16 +02:00
braginini
f5e974c04c Bind test 2022-09-04 22:52:52 +02:00
113 changed files with 2242 additions and 8603 deletions

View File

@@ -1,10 +1,5 @@
name: Test Code Darwin
on:
push:
branches:
- main
pull_request:
on: [push,pull_request]
jobs:
test:

View File

@@ -1,10 +1,5 @@
name: Test Code Linux
on:
push:
branches:
- main
pull_request:
on: [push,pull_request]
jobs:
test:
@@ -38,55 +33,3 @@ jobs:
- name: Test
run: GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v2
with:
go-version: 1.18.x
- name: Cache Go modules
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v2
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libappindicator3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: Install modules
run: go mod tidy
- name: Generate Iface Test bin
run: go test -c -o iface-testing.bin ./iface/...
- name: Generate RouteManager Test bin
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
- name: Generate Engine Test bin
run: go test -c -o engine-testing.bin ./client/internal/*.go
- name: Generate Peer Test bin
run: go test -c -o peer-testing.bin ./client/internal/peer/...
- run: chmod +x *testing.bin
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -1,10 +1,5 @@
name: Test Code Windows
on:
push:
branches:
- main
pull_request:
on: [push,pull_request]
jobs:
pre:
@@ -25,6 +20,7 @@ jobs:
needs: pre
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v2

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.4"
SIGN_PIPE_VER: "v0.0.3"
GORELEASER_VER: "v1.6.3"
jobs:

View File

@@ -1,10 +1,5 @@
name: Test Docker Compose Linux
on:
push:
branches:
- main
pull_request:
on: [push,pull_request]
jobs:
test:
@@ -56,16 +51,13 @@ jobs:
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
run: |
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
grep AUTH_AUDIENCE docker-compose.yml | grep $CI_NETBIRD_AUTH_AUDIENCE
grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep USE_AUTH0 docker-compose.yml | grep $CI_NETBIRD_USE_AUTH0
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "http://localhost:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "http://localhost:33073"
- name: run docker compose up
working-directory: infrastructure_files

View File

@@ -41,7 +41,7 @@ builds:
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-signal

View File

@@ -1,6 +1,6 @@
<p align="center">
<strong>:hatching_chick: New Release! User Invites.</strong>
<a href="https://github.com/netbirdio/netbird/releases">
<strong>:hatching_chick: New release! NetBird Easy SSH</strong>.
<a href="https://github.com/netbirdio/netbird/releases/tag/v0.8.0">
Learn more
</a>
</p>
@@ -16,7 +16,7 @@
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=netbirdio/netbird&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
<img src="https://img.shields.io/badge/slack-@wiretrustee-red.svg?logo=slack"/>
</a>
</p>
</div>
@@ -43,27 +43,29 @@ It requires zero configuration effort leaving behind the hassle of opening ports
NetBird creates an overlay peer-to-peer network connecting machines automatically regardless of their location (home, office, datacenter, container, cloud or edge environments) unifying virtual private network management experience.
**Key features:**
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
- \[x] Multiuser support - sharing network between multiple users.
- \[x] SSO and MFA support.
- \[x] Multicloud and hybrid-cloud support.
- \[x] Kernel WireGuard usage when possible.
- \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys.
- \[x] Network Routes.
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
- \[x] Multiuser support - sharing network between multiple users.
- \[x] SSO and MFA support.
- \[x] Multicloud and hybrid-cloud support.
- \[x] Kernel WireGuard usage when possible.
- \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys.
**Coming soon:**
- \[ ] Network Routes.
- \[ ] Private DNS.
- \[ ] Mobile clients.
- \[ ] Network Activity Monitoring.
### Secure peer-to-peer VPN with SSO and MFA in minutes
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
<p float="left" align="middle">
<img src="docs/media/peerA.gif" width="400"/>
<img src="docs/media/peerB.gif" width="400"/>
</p>
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
@@ -103,5 +105,5 @@ See a complete [architecture overview](https://netbird.io/docs/overview/architec
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
### Legal
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
[WireGuard](https://wireguard.com/) is a registered trademark of Jason A. Donenfeld.

View File

@@ -11,7 +11,6 @@ import (
"github.com/netbirdio/netbird/util"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"net"
"net/netip"
"sort"
"strings"
@@ -19,7 +18,6 @@ import (
var (
detailFlag bool
ipv4Flag bool
ipsFilter []string
statusFilter string
ipsFilterMap map[string]struct{}
@@ -75,7 +73,7 @@ var statusCmd = &cobra.Command{
pbFullStatus := resp.GetFullStatus()
fullStatus := fromProtoFullStatus(pbFullStatus)
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion(), ipv4Flag))
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion()))
return nil
},
@@ -84,9 +82,8 @@ var statusCmd = &cobra.Command{
func init() {
ipsFilterMap = make(map[string]struct{})
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g. --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g. --filter-by-status connected")
}
func parseFilters() error {
@@ -145,19 +142,7 @@ func fromProtoFullStatus(pbFullStatus *proto.FullStatus) nbStatus.FullStatus {
return fullStatus
}
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string, flag bool) string {
interfaceIP := fullStatus.LocalPeerState.IP
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {
return ""
}
if ipv4Flag {
return fmt.Sprintf("%s\n", ip)
}
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string) string {
var (
managementStatusURL = ""
signalStatusURL = ""
@@ -179,6 +164,8 @@ func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonSta
signalConnString = "Connected"
}
interfaceIP := fullStatus.LocalPeerState.IP
if fullStatus.LocalPeerState.KernelInterface {
interfaceTypeString = "Kernel"
} else if fullStatus.LocalPeerState.IP == "" {

View File

@@ -68,12 +68,12 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
}
peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "")
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
if err != nil {
t.Fatal(err)
}

98
client/hhhh.go Normal file
View File

@@ -0,0 +1,98 @@
package main
/*
import (
"flag"
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"net/http"
_ "net/http/pprof"
"time"
)
var name = flag.String("name", "wg0", "WireGuard interface name")
var addr = flag.String("addr", "100.64.0.1/24", "interface WireGuard IP addr")
var key = flag.String("key", "100.64.0.1/24", "WireGuard private key")
var port = flag.Int("port", 51820, "WireGuard port")
var remoteKey = flag.String("remote-key", "", "remote WireGuard public key")
var remoteAddr = flag.String("remote-addr", "100.64.0.2/32", "remote WireGuard IP addr")
var remoteEndpoint = flag.String("remote-endpoint", "127.0.0.1:51820", "remote WireGuard endpoint")
func fff() {
flag.Parse()
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
myKey, err := wgtypes.ParseKey(*key)
if err != nil {
log.Error(err)
return
}
log.Infof("public key and addr [%s] [%s] ", myKey.PublicKey().String(), *addr)
wgIFace, err := iface.NewWGIFace(*name, *addr, 1280)
if err != nil {
log.Error(err)
return
}
defer wgIFace.Close()
// todo wrap into UDPMux
sharedSock, _, err := listenNet("udp4", *port)
if err != nil {
log.Error(err)
return
}
defer sharedSock.Close()
// err = wgIFace.Create()
err = wgIFace.CreateNew(sharedSock)
if err != nil {
log.Errorf("failed to create interface %s %v", *name, err)
return
}
err = wgIFace.Configure(*key, *port)
if err != nil {
log.Errorf("failed to configure interface %s %v", *name, err)
return
}
ip, err := net.ResolveUDPAddr("udp4", *remoteEndpoint)
if err != nil {
// handle error
}
err = wgIFace.UpdatePeer(*remoteKey, *remoteAddr, 20*time.Second, ip, nil)
if err != nil {
log.Errorf("failed to configure remote peer %s %v", *remoteKey, err)
return
}
select {}
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}*/

View File

@@ -101,7 +101,6 @@ done:
Pop $2
Exch $1
FunctionEnd
!macro GetAppFromCommand in out
Push "${in}"
Call GetAppFromCommand
@@ -118,7 +117,7 @@ Call GetAppFromCommand ; Remove quotes and parameters from UninstCommand
Pop $0
Pop $1
GetFullPathName $2 "$0\.."
ExecWait '"$0" /S $1 _?=$2'
ExecWait '"$0" $1 _?=$2'
Delete "$0" ; Extra cleanup because we used _?=
RMDir "$2"
Pop $2
@@ -127,27 +126,30 @@ Pop $0
!macroend
Function .onInit
StrCpy $INSTDIR "${INSTALL_DIR}"
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\Wiretrustee" "UninstallString"
${If} $R0 != ""
MessageBox MB_YESNO|MB_ICONQUESTION "Wiretrustee is installed. We must remove it before installing Netbird. Procced?" IDNO noWTUninstOld
!insertmacro UninstallPreviousNSIS $R0 "/NoMsgBox"
noWTUninstOld:
${EndIf}
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
${If} $R0 != ""
# if silent install jump to uninstall step
IfSilent uninstall
MessageBox MB_YESNO|MB_ICONQUESTION "NetBird is already installed. We must remove it before installing upgrading NetBird. Proceed?" IDNO done IDYES uninstall
uninstall:
!insertmacro UninstallPreviousNSIS $R0 "/NoMsgBox"
done:
MessageBox MB_YESNO|MB_ICONQUESTION "$(^NAME) is already installed. Do you want to remove the previous version?" IDNO noUninstOld
!insertmacro UninstallPreviousNSIS $R0 "/NoMsgBox"
noUninstOld:
${EndIf}
FunctionEnd
######################################################################
Section -MainProgram
${INSTALL_TYPE}
# SetOverwrite ifnewer
SetOverwrite ifnewer
SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\"
SectionEnd
######################################################################
Section -Icons_Reg
@@ -170,29 +172,24 @@ SetShellVarContext current
CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
SetShellVarContext all
SectionEnd
Section -Post
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service start'
Exec '"$INSTDIR\${MAIN_APP_EXE}" service start'
# sleep a bit for visibility
Sleep 1000
SectionEnd
######################################################################
Section Uninstall
${INSTALL_TYPE}
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
Exec '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe`
# wait the service uninstall take unblock the executable
Sleep 3000
Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}"
RmDir /r "$INSTDIR"
SetShellVarContext current
@@ -212,4 +209,4 @@ SetShellVarContext current
SetOutPath $INSTDIR
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
SetShellVarContext all
FunctionEnd
FunctionEnd

View File

@@ -263,7 +263,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
}
}
deviceAuthorizationFlow := DeviceAuthorizationFlow{
return DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: ProviderConfig{
@@ -274,29 +274,5 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
},
}
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
return deviceAuthorizationFlow, nil
}
func isProviderConfigValid(config ProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
return nil
}, nil
}

View File

@@ -107,7 +107,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
localPeerState := nbStatus.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModuleIsLoaded(),
KernelInterface: iface.WireguardModExists(),
}
statusRecorder.UpdateLocalPeerState(localPeerState)

View File

@@ -1,56 +0,0 @@
package dns
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"sync"
)
type localResolver struct {
registeredMap registrationMap
records sync.Map
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received question: %#v\n", r.Question[0])
response := d.lookupRecord(r)
if response == nil {
log.Debugf("got empty response for question: %#v\n", r.Question[0])
return
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.Answer = append(replyMessage.Answer, response)
err := w.WriteMsg(replyMessage)
if err != nil {
log.Debugf("got an error while writing the local resolver response, error: %v", err)
}
}
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
record, found := d.records.Load(r.Question[0].Name)
if !found {
return nil
}
return record.(dns.RR)
}
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
fullRecord, err := dns.NewRR(record.String())
if err != nil {
return err
}
d.records.Store(fullRecord.Header().Name, fullRecord)
return nil
}
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}

View File

@@ -1,86 +0,0 @@
package dns
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"strings"
"testing"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}

View File

@@ -1,25 +0,0 @@
package dns
import (
"github.com/miekg/dns"
"net"
)
type mockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *mockResponseWriter) Close() error { return nil }
func (rw *mockResponseWriter) TsigStatus() error { return nil }
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
func (rw *mockResponseWriter) Hijack() {}

View File

@@ -1,270 +0,0 @@
package dns
import (
"context"
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"sync"
"time"
)
const (
port = 5053
defaultIP = "0.0.0.0"
)
// Server dns server object
type Server struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registrationMap
localResolver *localResolver
updateSerial uint64
listenerIsRunning bool
}
type registrationMap map[string]struct{}
type muxUpdate struct {
domain string
handler dns.Handler
}
// NewServer returns a new dns server
func NewServer(ctx context.Context) *Server {
mux := dns.NewServeMux()
dnsServer := &dns.Server{
Addr: fmt.Sprintf("%s:%d", defaultIP, port),
Net: "udp",
Handler: mux,
UDPSize: 65535,
}
ctx, stop := context.WithCancel(ctx)
return &Server{
ctx: ctx,
stop: stop,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registrationMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
}
}
// Start runs the listener in a go routine
func (s *Server) Start() {
log.Debugf("starting dns on %s:%d", defaultIP, port)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server returned an error: %v", err)
}
}()
}
func (s *Server) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
// Stop stops the server
func (s *Server) Stop() {
s.stop()
err := s.stopListener()
if err != nil {
log.Error(err)
}
}
func (s *Server) stopListener() error {
if !s.listenerIsRunning {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
}
return nil
}
// UpdateDNSServer processes an update received from the management service
func (s *Server) UpdateDNSServer(serial uint64, update nbdns.Update) error {
select {
case <-s.ctx.Done():
log.Infof("not updating DNS server as context is closed")
return s.ctx.Err()
default:
if serial < s.updateSerial {
return fmt.Errorf("not applying dns update, error: "+
"network update is %d behind the last applied update", s.updateSerial-serial)
}
s.mux.Lock()
defer s.mux.Unlock()
// is the service should be disabled, we stop the listener
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
err := s.stopListener()
if err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.Start()
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
s.updateSerial = serial
return nil
}
}
func (s *Server) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
var muxUpdates []muxUpdate
localRecords := make(map[string]nbdns.SimpleRecord, 0)
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
return nil, nil, fmt.Errorf("received an empty list of records")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: customZone.Domain,
handler: s.localResolver,
})
for _, record := range customZone.Records {
localRecords[record.Name] = record
}
}
return muxUpdates, localRecords, nil
}
func (s *Server) buildUpstreamHandlerUpdate(nameServerGroups []nbdns.NameServerGroup) ([]muxUpdate, error) {
var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups {
if len(nsGroup.NameServers) == 0 {
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
}
handler := &upstreamResolver{
parentCTX: s.ctx,
upstreamClient: &dns.Client{},
upstreamTimeout: defaultUpstreamTimeout,
}
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue
}
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
}
if len(handler.upstreamServers) == 0 {
log.Errorf("received a nameserver group with an invalid nameserver list")
continue
}
if nsGroup.Primary {
muxUpdates = append(muxUpdates, muxUpdate{
domain: nbdns.RootZone,
handler: handler,
})
continue
}
if len(nsGroup.Domains) == 0 {
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
}
for _, domain := range nsGroup.Domains {
if domain == "" {
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: domain,
handler: handler,
})
}
}
return muxUpdates, nil
}
func (s *Server) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registrationMap)
for _, update := range muxUpdates {
s.registerMux(update.domain, update.handler)
muxUpdateMap[update.domain] = struct{}{}
}
for key := range s.dnsMuxMap {
_, found := muxUpdateMap[key]
if !found {
s.deregisterMux(key)
}
}
s.dnsMuxMap = muxUpdateMap
}
func (s *Server) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
for key := range s.localResolver.registeredMap {
_, found := update[key]
if !found {
s.localResolver.deleteRecord(key)
}
}
updatedMap := make(registrationMap)
for key, record := range update {
err := s.localResolver.registerRecord(record)
if err != nil {
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
}
updatedMap[key] = struct{}{}
}
s.localResolver.registeredMap = updatedMap
}
func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
}
func (s *Server) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *Server) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}

View File

@@ -1,285 +0,0 @@
package dns
import (
"context"
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
)
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
},
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap registrationMap
initLocalMap registrationMap
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Update
shouldFail bool
expectedUpstreamMap registrationMap
expectedLocalMap registrationMap
}{
{
name: "Initial Update Should Succeed",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}},
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
},
{
name: "New Update Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}},
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
},
{
name: "Smaller Update Serial Should Be Skipped",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
shouldFail: true,
},
{
name: "Empty Update Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{ServiceEnable: true},
expectedUpstreamMap: make(registrationMap),
expectedLocalMap: make(registrationMap),
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx := context.Background()
dnsServer := NewServer(ctx)
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial
dnsServer.listenerIsRunning = true
err := dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
}
for key := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key]
if !found {
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
}
}
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
}
for key := range testCase.expectedLocalMap {
_, found := dnsServer.localResolver.registeredMap[key]
if !found {
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
}
}
})
}
}
func TestDNSServerStartStop(t *testing.T) {
ctx := context.Background()
dnsServer := NewServer(ctx)
if runtime.GOOS == "windows" && os.Getenv("CI") == "true" {
// todo review why this test is not working only on github actions workflows
t.Skip("skipping test in Windows CI workflows.")
}
dnsServer.Start()
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 5,
}
addr := fmt.Sprintf("127.0.0.1:%d", port)
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
t.Log(err)
// retry test before exit, for slower systems
return d.DialContext(ctx, network, addr)
}
return conn, nil
},
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed to connect to the server, error: %v", err)
}
t.Log(ips)
if ips[0] != zoneRecords[0].RData {
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
}
dnsServer.Stop()
ctx, cancel := context.WithTimeout(ctx, time.Second*1)
defer cancel()
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
if err == nil {
t.Fatalf("we should encounter an error when querying a stopped server")
}
}

View File

@@ -1,67 +0,0 @@
package dns
import (
"context"
"errors"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"net"
"time"
)
const defaultUpstreamTimeout = 15 * time.Second
type upstreamResolver struct {
parentCTX context.Context
upstreamClient *dns.Client
upstreamServers []string
upstreamTimeout time.Duration
}
// ServeDNS handles a DNS request
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received an upstream question: %#v", r.Question[0])
select {
case <-u.parentCTX.Done():
return
default:
}
for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) {
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
continue
}
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
return
}
log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm)
if err != nil {
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
}
return
}
log.Errorf("all queries to the upstream nameservers failed with timeout")
}
// isTimeout returns true if the given error is a network timeout error.
//
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
func isTimeout(err error) bool {
var neterr net.Error
if errors.As(err, &neterr) {
return neterr != nil && neterr.Timeout()
}
return false
}

View File

@@ -1,110 +0,0 @@
package dns
import (
"context"
"github.com/miekg/dns"
"strings"
"testing"
"time"
)
func TestUpstreamResolver_ServeDNS(t *testing.T) {
testCases := []struct {
name string
inputMSG *dns.Msg
responseShouldBeNil bool
InputServers []string
timeout time.Duration
cancelCTX bool
expectedAnswer string
}{
{
name: "Should Resolve A Record",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
timeout: defaultUpstreamTimeout,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Resolve If First Upstream Times Out",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
timeout: 2 * time.Second,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Not Resolve If Can't Connect To Both Servers",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
timeout: 200 * time.Millisecond,
responseShouldBeNil: true,
},
{
name: "Should Not Resolve If Parent Context Is Canceled",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
cancelCTX: true,
timeout: defaultUpstreamTimeout,
responseShouldBeNil: true,
},
//{
// name: "Should Resolve CNAME Record",
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
//},
//{
// name: "Should Not Write When Not Found A Record",
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
// responseShouldBeNil: true,
//},
}
// should resolve if first upstream times out
// should not write when both fails
// should not resolve if parent context is canceled
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver := &upstreamResolver{
parentCTX: ctx,
upstreamClient: &dns.Client{},
upstreamServers: testCase.InputServers,
upstreamTimeout: testCase.timeout,
}
if testCase.cancelCTX {
cancel()
} else {
defer cancel()
}
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
foundAnswer := false
for _, answer := range responseMSG.Answer {
if strings.Contains(answer.String(), testCase.expectedAnswer) {
foundAnswer = true
break
}
}
if !foundAnswer {
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
}
})
}
}

View File

@@ -3,13 +3,11 @@ package internal
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/route"
"math/rand"
"net"
"reflect"
"runtime"
"strings"
@@ -90,10 +88,7 @@ type Engine struct {
wgInterface *iface.WGIface
udpMux ice.UDPMux
udpMuxSrflx ice.UniversalUDPMux
udpMuxConn *net.UDPConn
udpMuxConnSrflx *net.UDPConn
iceMux ice.UniversalUDPMux
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
@@ -104,8 +99,6 @@ type Engine struct {
statusRecorder *nbstatus.Status
routeManager routemanager.Manager
dnsServer *dns.Server
}
// Peer is an instance of the Connection Peer
@@ -133,7 +126,6 @@ func NewEngine(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
dnsServer: dns.NewServer(ctx),
}
}
@@ -159,30 +151,6 @@ func (e *Engine) Stop() error {
}
}
if e.udpMux != nil {
if err := e.udpMux.Close(); err != nil {
log.Debugf("close udp mux: %v", err)
}
}
if e.udpMuxSrflx != nil {
if err := e.udpMuxSrflx.Close(); err != nil {
log.Debugf("close server reflexive udp mux: %v", err)
}
}
if e.udpMuxConn != nil {
if err := e.udpMuxConn.Close(); err != nil {
log.Debugf("close udp mux connection: %v", err)
}
}
if e.udpMuxConnSrflx != nil {
if err := e.udpMuxConnSrflx.Close(); err != nil {
log.Debugf("close server reflexive udp mux connection: %v", err)
}
}
if !isNil(e.sshServer) {
err := e.sshServer.Stop()
if err != nil {
@@ -194,10 +162,6 @@ func (e *Engine) Stop() error {
e.routeManager.Stop()
}
if e.dnsServer != nil {
e.dnsServer.Stop()
}
log.Infof("stopped Netbird Engine")
return nil
@@ -221,34 +185,34 @@ func (e *Engine) Start() error {
return err
}
e.udpMuxConn, err = net.ListenUDP("udp4", &net.UDPAddr{Port: e.config.UDPMuxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
return err
}
e.udpMuxConnSrflx, err = net.ListenUDP("udp4", &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
return err
}
e.udpMux = ice.NewUDPMuxDefault(ice.UDPMuxParams{UDPConn: e.udpMuxConn})
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx})
err = e.wgInterface.Create()
bind := &iface.ICEBind{}
err = e.wgInterface.CreateNew(bind)
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error())
return err
}
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
port, err := e.wgInterface.GetListenPort()
if err != nil {
return err
}
err = e.wgInterface.Configure(myPrivateKey.String(), *port)
if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error())
return err
}
iceMux, err := bind.GetICEMux()
if err != nil {
return err
}
e.iceMux = iceMux
log.Infof("NetBird Engine started listening on WireGuard port %d", *port)
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
e.config.WgPort = *port
e.receiveSignalEvents()
e.receiveManagementEvents()
@@ -768,8 +732,8 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
StunTurn: stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
Timeout: timeout,
UDPMux: e.udpMux,
UDPMuxSrflx: e.udpMuxSrflx,
UDPMux: e.iceMux,
UDPMuxSrflx: e.iceMux,
ProxyConfig: proxyConfig,
LocalWgPort: e.config.WgPort,
}

View File

@@ -761,12 +761,12 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager()
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "")
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
if err != nil {
return nil, err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
if err != nil {
return nil, err
}

View File

@@ -36,10 +36,7 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
log.Warnf("failed to close the Management service client %v", err)
}
}()

View File

@@ -147,7 +147,7 @@ func (conn *Conn) reCreateAgent() error {
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4},
Urls: conn.config.StunTurn,
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
CandidateTypes: []ice.CandidateType{ice.CandidateTypeServerReflexive, ice.CandidateTypeHost, ice.CandidateTypeRelay},
FailedTimeout: &failedTimeout,
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
@@ -280,14 +280,7 @@ func (conn *Conn) Open() error {
return err
}
if conn.proxy.Type() == proxy.TypeNoProxy {
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, iface.DefaultWgPort, rhost, iface.DefaultWgPort)
} else {
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
}
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
select {
@@ -351,15 +344,16 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
}
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
useProxy := shouldUseProxy(pair)
var p proxy.Proxy
if useProxy {
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
p = proxy.NewWireguardProxy(conn.config.ProxyConfig)
peerState.Direct = false
} else {
p = proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
peerState.Direct = true
}
conn.proxy = p
err = p.Start(remoteConn)
if err != nil {

View File

@@ -39,7 +39,6 @@ func (p *NoProxy) Start(remoteConn net.Conn) error {
if err != nil {
return err
}
addr.Port = p.RemoteWgListenPort
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)

View File

@@ -207,7 +207,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String())
if err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.GetAddress().IP.String(), err)
c.chosenRoute.Network.String(), c.wgInterface.GetAddress().IP.String(), err)
}
}

View File

@@ -9,16 +9,14 @@ import (
import "github.com/google/nftables"
const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
ipv6Nat = "netbird-rt-ipv6-nat"
ipv4Nat = "netbird-rt-ipv4-nat"
natFormat = "netbird-nat-%s"
forwardingFormat = "netbird-fwd-%s"
inNatFormat = "netbird-nat-in-%s"
inForwardingFormat = "netbird-fwd-in-%s"
ipv6 = "ipv6"
ipv4 = "ipv4"
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
ipv6Nat = "netbird-rt-ipv6-nat"
ipv4Nat = "netbird-rt-ipv4-nat"
natFormat = "netbird-nat-%s"
forwardingFormat = "netbird-fwd-%s"
ipv6 = "ipv6"
ipv4 = "ipv4"
)
func genKey(format string, input string) string {
@@ -55,13 +53,3 @@ func NewFirewall(parentCTX context.Context) firewallManager {
return manager
}
func getInPair(pair routerPair) routerPair {
return routerPair{
ID: pair.ID,
// invert source/destination
source: pair.destination,
destination: pair.source,
masquerade: pair.masquerade,
}
}

View File

@@ -311,37 +311,7 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair))
if err != nil {
return err
}
return nil
}
// insertRoutingRule inserts an iptable rule
func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error {
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
@@ -350,22 +320,43 @@ func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string
ipVersion = ipv6
}
ruleKey := genKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][ruleKey]
forwardRuleKey := genKey(forwardingFormat, pair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][forwardRuleKey]
if found {
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
}
delete(i.rules[ipVersion], ruleKey)
delete(i.rules[ipVersion], forwardRuleKey)
}
err = iptablesClient.Insert(table, chain, 1, rule...)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err)
}
i.rules[ipVersion][ruleKey] = rule
i.rules[ipVersion][forwardRuleKey] = forwardRule
if !pair.masquerade {
return nil
}
natRuleKey := genKey(natFormat, pair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination)
existingRule, found = i.rules[ipVersion][natRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err)
}
delete(i.rules[ipVersion], natRuleKey)
}
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err)
}
i.rules[ipVersion][natRuleKey] = natRule
return nil
}
@@ -375,37 +366,7 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair))
if err != nil {
return err
}
return nil
}
// removeRoutingRule removes an iptables rule
func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error {
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
@@ -414,23 +375,29 @@ func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair
ipVersion = ipv6
}
ruleKey := genKey(keyFormat, pair.ID)
existingRule, found := i.rules[ipVersion][ruleKey]
forwardRuleKey := genKey(forwardingFormat, pair.ID)
existingRule, found := i.rules[ipVersion][forwardRuleKey]
if found {
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
}
}
delete(i.rules[ipVersion], ruleKey)
delete(i.rules[ipVersion], forwardRuleKey)
if !pair.masquerade {
return nil
}
natRuleKey := genKey(natFormat, pair.ID)
existingRule, found = i.rules[ipVersion][natRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err)
}
}
delete(i.rules[ipVersion], natRuleKey)
return nil
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == iptablesNatTable {
ruleType = "nat"
}
return ruleType
}

View File

@@ -159,17 +159,6 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
@@ -183,23 +172,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
if testCase.inputPair.masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
require.False(t, foundNat, "nat rule should exist in the map")
}
})
}
@@ -240,24 +213,12 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4)
delete(manager.rules, ipv6)
@@ -274,26 +235,12 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
require.False(t, found, "forwarding rule should exist in the manager map")
})
}

View File

@@ -12,6 +12,7 @@ import (
)
import "github.com/google/nftables"
//
const (
nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd"
@@ -83,10 +84,8 @@ func (n *nftablesManager) CleanRoutingRules() {
n.mux.Lock()
defer n.mux.Unlock()
log.Debug("flushing tables")
if n.tableIPv4 != nil && n.tableIPv6 != nil {
n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4)
}
n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4)
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
}
@@ -247,77 +246,53 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
n.mux.Lock()
defer n.mux.Unlock()
err := n.refreshRulesMap()
if err != nil {
return err
}
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
if err != nil {
return err
}
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
if err != nil {
return err
}
if pair.masquerade {
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
if err != nil {
return err
}
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
if err != nil {
return err
}
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
prefix := netip.MustParsePrefix(pair.source)
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
fwdKey := genKey(forwardingFormat, pair.ID)
if prefix.Addr().Unmap().Is4() {
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...)
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
}
ruleKey := genKey(format, pair.ID)
if pair.masquerade {
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
natKey := genKey(natFormat, pair.ID)
_, exists := n.rules[ruleKey]
if exists {
err := n.removeRoutingRule(format, pair)
if err != nil {
return err
if prefix.Addr().Unmap().Is4() {
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natKey),
})
} else {
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natKey),
})
}
}
if prefix.Addr().Unmap().Is4() {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
} else {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
err := n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
}
return nil
}
@@ -332,26 +307,26 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return err
}
err = n.removeRoutingRule(forwardingFormat, pair)
if err != nil {
return err
fwdKey := genKey(forwardingFormat, pair.ID)
natKey := genKey(natFormat, pair.ID)
fwdRule, found := n.rules[fwdKey]
if found {
err = n.conn.DelRule(fwdRule)
if err != nil {
return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removing forwarding rule for %s", pair.destination)
delete(n.rules, fwdKey)
}
err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
if err != nil {
return err
natRule, found := n.rules[natKey]
if found {
err = n.conn.DelRule(natRule)
if err != nil {
return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removing nat rule for %s", pair.destination)
delete(n.rules, natKey)
}
err = n.removeRoutingRule(natFormat, pair)
if err != nil {
return err
}
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
if err != nil {
return err
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
@@ -360,29 +335,6 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
ruleKey := genKey(format, pair.ID)
rule, found := n.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := n.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
delete(n.rules, ruleKey)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch {

View File

@@ -189,45 +189,6 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
testingExpression = append(sourceExp, destExp...)
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
found = 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.inputPair.masquerade {
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
@@ -280,28 +241,6 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
UserData: []byte(natRuleKey),
})
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...)
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
@@ -320,10 +259,8 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist")
}
}
}

View File

@@ -21,7 +21,7 @@ func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
}
if prefixGateway != nil && !prefixGateway.Equal(gateway) {
log.Warnf("skipping adding a new route for network %s because it already exists and is pointing to the non default gateway: %s", prefix, prefixGateway)
log.Warnf("route for network %s already exist and is pointing to the gateway: %s, won't add another one", prefix, prefixGateway)
return nil
}
return addToRouteTable(prefix, addr)
@@ -45,14 +45,11 @@ func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
if err != nil {
return nil, err
}
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
_, _, localGatewayAddress, err := r.Route(prefix.Addr().AsSlice())
if err != nil {
log.Errorf("getting routes returned an error: %v", err)
return nil, errRouteNotFound
}
if gateway == nil {
return preferredSrc, nil
}
return gateway, nil
return localGatewayAddress, nil
}

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"github.com/netbirdio/netbird/iface"
"github.com/stretchr/testify/require"
"net"
"net/netip"
"testing"
)
@@ -67,45 +66,3 @@ func TestAddRemoveRoutes(t *testing.T) {
})
}
}
func TestGetExistingRIBRouteGateway(t *testing.T) {
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if gateway == nil {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var testingIP string
var testingPrefix netip.Prefix
for _, address := range addresses {
if address.Network() != "ip+net" {
continue
}
prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked()
break
}
}
localIP, err := getExistingRIBRouteGateway(testingPrefix)
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if localIP == nil {
t.Fatal("should return a gateway for local network")
}
if localIP.String() == gateway.String() {
t.Fatal("local ip should not match with gateway IP")
}
if localIP.String() != testingIP {
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
}
}

View File

@@ -1,17 +1,36 @@
package system
import (
"bytes"
"context"
"fmt"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
"os"
"os/exec"
"runtime"
"strings"
)
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
ver := getOSVersion()
cmd := exec.Command("cmd", "ver")
cmd.Stdin = strings.NewReader("some")
var out bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
panic(err)
}
osStr := strings.Replace(out.String(), "\n", "", -1)
osStr = strings.Replace(osStr, "\r\n", "", -1)
tmp1 := strings.Index(osStr, "[Version")
tmp2 := strings.Index(osStr, "]")
var ver string
if tmp1 == -1 || tmp2 == -1 {
ver = "unknown"
} else {
ver = osStr[tmp1+9 : tmp2]
}
gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname()
gio.WiretrusteeVersion = NetbirdVersion()
@@ -19,37 +38,3 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
func getOSVersion() string {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
if err != nil {
log.Error(err)
return "0.0.0.0"
}
defer func() {
deferErr := k.Close()
if deferErr != nil {
log.Error(deferErr)
}
}()
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
if err != nil {
log.Error(err)
}
minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
if err != nil {
log.Error(err)
}
build, _, err := k.GetStringValue("CurrentBuildNumber")
if err != nil {
log.Error(err)
}
// Update Build Revision
ubr, _, err := k.GetIntegerValue("UBR")
if err != nil {
log.Error(err)
}
ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
return ver
}

View File

@@ -8,6 +8,7 @@ import (
"context"
"flag"
"fmt"
"github.com/netbirdio/netbird/client/system"
"os"
"os/exec"
"path"
@@ -17,8 +18,6 @@ import (
"syscall"
"time"
"github.com/netbirdio/netbird/client/system"
"github.com/cenkalti/backoff/v4"
_ "embed"
@@ -62,8 +61,6 @@ func main() {
flag.Parse()
a := app.New()
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
client := newServiceClient(daemonAddr, a, showSettings)
if showSettings {
a.Run()
@@ -116,7 +113,7 @@ type serviceClient struct {
iLogFile *widget.Entry
iPreSharedKey *widget.Entry
// observable settings over corresponding iMngURL and iPreSharedKey values.
// observable settings over correspondign iMngURL and iPreSharedKey values.
managementURL string
preSharedKey string
adminURL string
@@ -124,7 +121,7 @@ type serviceClient struct {
// newServiceClient instance constructor
//
// This constructor also builds the UI elements for the settings window.
// This constructor olso build UI elements for settings window.
func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient {
s := &serviceClient{
ctx: context.Background(),
@@ -152,7 +149,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
func (s *serviceClient) showUIElements() {
// add settings window UI elements.
s.wSettings = s.app.NewWindow("NetBird Settings")
s.wSettings = s.app.NewWindow("Settings")
s.iMngURL = widget.NewEntry()
s.iAdminURL = widget.NewEntry()
s.iConfigFile = widget.NewEntry()
@@ -327,15 +324,13 @@ func (s *serviceClient) updateStatus() error {
return err
}
if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() {
if status.Status == string(internal.StatusConnected) {
systray.SetIcon(s.icConnected)
systray.SetTooltip("NetBird (Connected)")
s.mStatus.SetTitle("Connected")
s.mUp.Disable()
s.mDown.Enable()
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
} else {
systray.SetIcon(s.icDisconnected)
systray.SetTooltip("NetBird (Disconnected)")
s.mStatus.SetTitle("Disconnected")
s.mDown.Disable()
s.mUp.Enable()
@@ -360,7 +355,6 @@ func (s *serviceClient) updateStatus() error {
func (s *serviceClient) onTrayReady() {
systray.SetIcon(s.icDisconnected)
systray.SetTooltip("NetBird")
// setup systray menu items
s.mStatus = systray.AddMenuItem("Disconnected", "Disconnected")

View File

@@ -1,56 +0,0 @@
// Package dns implement dns types and standard methods and functions
// to parse and normalize dns records and configuration
package dns
import (
"fmt"
"github.com/miekg/dns"
)
const (
// DefaultDNSPort well-known port number
DefaultDNSPort = 53
// RootZone is a string representation of the root zone
RootZone = "."
// DefaultClass is the class supported by the system
DefaultClass = "IN"
)
// Update represents a dns update that is exchanged between management and peers
type Update struct {
// ServiceEnable indicates if the service should be enabled
ServiceEnable bool
// NameServerGroups contains a list of nameserver group
NameServerGroups []NameServerGroup
// CustomZones contains a list of custom zone
CustomZones []CustomZone
}
// CustomZone represents a custom zone to be resolved by the dns server
type CustomZone struct {
// Domain is the zone's domain
Domain string
// Records custom zone records
Records []SimpleRecord
}
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
type SimpleRecord struct {
// Name domain name
Name string
// Type of record, 1 for A, 5 for CNAME, 28 for AAAA. see https://pkg.go.dev/github.com/miekg/dns@v1.1.41#pkg-constants
Type int
// Class dns class, currently use the DefaultClass for all records
Class string
// TTL time-to-live for the record
TTL int
// RData is the actual value resolved in a dns query
RData string
}
// String returns a string of the simple record formatted as:
// <Name> <TTL> <Class> <Type> <RDATA>
func (s SimpleRecord) String() string {
fqdn := dns.Fqdn(s.Name)
return fmt.Sprintf("%s %d %s %s %s", fqdn, s.TTL, s.Class, dns.Type(s.Type).String(), s.RData)
}

View File

@@ -1,192 +0,0 @@
package dns
import (
"fmt"
"net/netip"
"net/url"
"strconv"
"strings"
)
const (
// InvalidNameServerType invalid nameserver type
InvalidNameServerType NameServerType = iota
// UDPNameServerType udp nameserver type
UDPNameServerType
)
const (
// MaxGroupNameChar maximum group name size
MaxGroupNameChar = 40
// InvalidNameServerTypeString invalid nameserver type as string
InvalidNameServerTypeString = "invalid"
// UDPNameServerTypeString udp nameserver type as string
UDPNameServerTypeString = "udp"
)
// NameServerType nameserver type
type NameServerType int
// String returns nameserver type string
func (n NameServerType) String() string {
switch n {
case UDPNameServerType:
return UDPNameServerTypeString
default:
return InvalidNameServerTypeString
}
}
// ToNameServerType returns a nameserver type
func ToNameServerType(typeString string) NameServerType {
switch typeString {
case UDPNameServerTypeString:
return UDPNameServerType
default:
return InvalidNameServerType
}
}
// NameServerGroup group of nameservers and with group ids
type NameServerGroup struct {
// ID identifier of group
ID string
// Name group name
Name string
// Description group description
Description string
// NameServers list of nameservers
NameServers []NameServer
// Groups list of peer group IDs to distribute the nameservers information
Groups []string
// Primary indicates that the nameserver group is the primary resolver for any dns query
Primary bool
// Domains indicate the dns query domains to use with this nameserver group
Domains []string
// Enabled group status
Enabled bool
}
// NameServer represents a DNS nameserver
type NameServer struct {
// IP address of nameserver
IP netip.Addr
// NSType nameserver type
NSType NameServerType
// Port nameserver listening port
Port int
}
// Copy copies a nameserver object
func (n *NameServer) Copy() *NameServer {
return &NameServer{
IP: n.IP,
NSType: n.NSType,
Port: n.Port,
}
}
// IsEqual compares one nameserver with the other
func (n *NameServer) IsEqual(other *NameServer) bool {
return other.IP == n.IP &&
other.NSType == n.NSType &&
other.Port == n.Port
}
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
func ParseNameServerURL(nsURL string) (NameServer, error) {
parsedURL, err := url.Parse(nsURL)
if err != nil {
return NameServer{}, err
}
var ns NameServer
parsedScheme := strings.ToLower(parsedURL.Scheme)
nsType := ToNameServerType(parsedScheme)
if nsType == InvalidNameServerType {
return NameServer{}, fmt.Errorf("invalid nameserver url schema type, got %s", parsedScheme)
}
ns.NSType = nsType
parsedPort, err := strconv.Atoi(parsedURL.Port())
if err != nil {
return NameServer{}, fmt.Errorf("invalid nameserver url port, got %s", parsedURL.Port())
}
ns.Port = parsedPort
parsedAddr, err := netip.ParseAddr(parsedURL.Hostname())
if err != nil {
return NameServer{}, fmt.Errorf("invalid nameserver url IP, got %s", parsedURL.Hostname())
}
ns.IP = parsedAddr
return ns, nil
}
// Copy copies a nameserver group object
func (g *NameServerGroup) Copy() *NameServerGroup {
return &NameServerGroup{
ID: g.ID,
Name: g.Name,
Description: g.Description,
NameServers: g.NameServers,
Groups: g.Groups,
Enabled: g.Enabled,
Primary: g.Primary,
Domains: g.Domains,
}
}
// IsEqual compares one nameserver group with the other
func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool {
return other.ID == g.ID &&
other.Name == g.Name &&
other.Description == g.Description &&
other.Primary == g.Primary &&
compareNameServerList(g.NameServers, other.NameServers) &&
compareGroupsList(g.Groups, other.Groups) &&
compareGroupsList(g.Domains, other.Domains)
}
func compareNameServerList(list, other []NameServer) bool {
if len(list) != len(other) {
return false
}
for _, ns := range list {
if !containsNameServer(ns, other) {
return false
}
}
return true
}
func containsNameServer(element NameServer, list []NameServer) bool {
for _, ns := range list {
if ns.IsEqual(&element) {
return true
}
}
return false
}
func compareGroupsList(list, other []string) bool {
if len(list) != len(other) {
return false
}
for _, id := range list {
match := false
for _, otherID := range other {
if id == otherID {
match = true
break
}
}
if !match {
return false
}
}
return true
}

51
go.mod
View File

@@ -11,19 +11,19 @@ require (
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7 //keep this version otherwise wiretrustee up command breaks
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.18.1
github.com/pion/ice/v2 v2.2.7
github.com/pion/ice/v2 v2.1.17
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.8.1
github.com/spf13/cobra v1.3.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/grpc v1.43.0
google.golang.org/protobuf v1.28.1
google.golang.org/protobuf v1.28.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -32,22 +32,21 @@ require (
github.com/c-robinson/iplib v1.0.3
github.com/coreos/go-iptables v0.6.0
github.com/creack/pty v1.1.18
github.com/eko/gocache/v3 v3.1.1
github.com/eko/gocache/v2 v2.3.1
github.com/getlantern/systray v1.2.1
github.com/gliderlabs/ssh v0.3.4
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/libp2p/go-netroute v0.2.0
github.com/magiconair/properties v1.8.5
github.com/miekg/dns v1.1.41
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.13.0
github.com/pion/logging v0.2.2
github.com/pion/stun v0.3.5
github.com/pion/transport v0.13.0
github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.8.0
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
go.opentelemetry.io/otel/metric v0.33.0
go.opentelemetry.io/otel/sdk/metric v0.33.0
golang.org/x/net v0.0.0-20220630215102-69896b714898
github.com/stretchr/testify v1.7.1
go.uber.org/zap v1.17.0
golang.org/x/net v0.0.0-20220513224357-95641704303c
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
)
@@ -70,13 +69,11 @@ require (
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/godbus/dbus/v5 v5.0.4 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/go-cmp v0.5.7 // indirect
github.com/google/gopacket v1.1.19 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
@@ -87,31 +84,27 @@ require (
github.com/nxadm/tail v1.4.8 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.1.5 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/dtls/v2 v2.1.2 // indirect
github.com/pion/mdns v0.0.5 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/stun v0.3.5 // indirect
github.com/pion/transport v0.13.1 // indirect
github.com/pion/turn/v2 v2.0.8 // indirect
github.com/pion/turn/v2 v2.0.7 // indirect
github.com/pion/udp v0.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.12.2 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/prometheus/common v0.33.0 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
github.com/yuin/goldmark v1.4.1 // indirect
go.opentelemetry.io/otel v1.11.1 // indirect
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect
golang.org/x/tools v0.1.10 // indirect
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect
@@ -122,11 +115,11 @@ require (
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
honnef.co/go/tools v0.2.2 // indirect
k8s.io/apimachinery v0.23.5 // indirect
)
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84
replace github.com/pion/ice/v2 => github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84

89
go.sum
View File

@@ -134,8 +134,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM=
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/eko/gocache/v3 v3.1.1 h1:r3CBwLnqPkcK56h9Do2CWw1kZ4TeKK0wDE1Oo/YZnhs=
github.com/eko/gocache/v3 v3.1.1/go.mod h1:UpP/LyHAioP/a/dizgl0MpgZ3A3CkS4NbG/mWkGTQ9M=
github.com/eko/gocache/v2 v2.3.1 h1:8MMkfqGJ0KIA9OXT0rXevcEIrU16oghrGDiIDJDFCa0=
github.com/eko/gocache/v2 v2.3.1/go.mod h1:l2z8OmpZHL0CpuzDJtxm267eF3mZW1NqUsMj+sKrbUs=
github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs=
@@ -178,6 +178,8 @@ github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 h1:XYzSdCbkzOC0F
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55/go.mod h1:6mmzY2kW1TOOrVy+r41Za2MxXM+hhqTtY3oBKd2AgFA=
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f h1:wrYrQttPS8FHIRSlsrcuKazukx/xqO/PpLZzZXsF+EA=
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA=
github.com/getlantern/systray v1.2.1 h1:udsC2k98v2hN359VTFShuQW6GGprRprw6kD6539JikI=
github.com/getlantern/systray v1.2.1/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
@@ -202,11 +204,6 @@ github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KE
github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU=
github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
@@ -283,8 +280,8 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -455,7 +452,6 @@ github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMV
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
@@ -478,8 +474,6 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84 h1:u8kpzR9ld1uAeH/BAXsS0SfcnhooLWeO7UgHSBVPD9I=
github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw=
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -511,10 +505,8 @@ github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTK
github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY8d4=
github.com/pegasus-kv/thrift v0.13.0/go.mod h1:Gl9NT/WHG6ABm6NsrbfE8LiJN0sAyneCrvB4qN4NPqQ=
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/pion/dtls/v2 v2.1.5 h1:jlh2vtIyUBShchoTDqpCCqiYCyRFJ/lvf/gQ8TALs+c=
github.com/pion/dtls/v2 v2.1.5/go.mod h1:BqCE7xPZbPSubGasRoDFJeTsyJtdD1FanJYL0JGheqY=
github.com/pion/ice/v2 v2.2.7 h1:kG9tux3WdYUSqqqnf+O5zKlpy41PdlvLUBlYJeV2emQ=
github.com/pion/ice/v2 v2.2.7/go.mod h1:Ckj7cWZ717rtU01YoDQA9ntGWCk95D42uVZ8sI0EL+8=
github.com/pion/dtls/v2 v2.1.2 h1:22Q1Jk9L++Yo7BIf9130MonNPfPVb+YgdYLeyQotuAA=
github.com/pion/dtls/v2 v2.1.2/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw=
@@ -524,11 +516,10 @@ github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TB
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
github.com/pion/transport v0.13.0 h1:KWTA5ZrQogizzYwPEciGtHPLwpAjE91FgXnyu+Hv2uY=
github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g=
github.com/pion/transport v0.13.1 h1:/UH5yLeQtwm2VZIPjxwnNFxjS4DFhyLfS4GlfuKUzfA=
github.com/pion/transport v0.13.1/go.mod h1:EBxbqzyv+ZrmDb82XswEE0BjfQFtuw1Nu6sjnjWCsGg=
github.com/pion/turn/v2 v2.0.8 h1:KEstL92OUN3k5k8qxsXHpr7WWfrdp7iJZHx99ud8muw=
github.com/pion/turn/v2 v2.0.8/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
github.com/pion/turn/v2 v2.0.7 h1:SZhc00WDovK6czaN1RSiHqbwANtIO6wfZQsU0m0KNE8=
github.com/pion/turn/v2 v2.0.7/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
github.com/pion/udp v0.1.1 h1:8UAPvyqmsxK8oOjloDk4wUt63TzFe9WEJkg5lChlj7o=
github.com/pion/udp v0.1.1/go.mod h1:6AFo+CMdKQm7UiA0eUPA8/eVCTx8jBIITLZHc9DWX5M=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
@@ -548,8 +539,8 @@ github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3O
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
github.com/prometheus/client_golang v1.13.0 h1:b71QUfeo5M8gq2+evJdTPfZhYMAU0uKPkyPJ7TPsloU=
github.com/prometheus/client_golang v1.13.0/go.mod h1:vTeo+zgvILHsnnj/39Ou/1fPN5nJFOEMgftOUOmlvYQ=
github.com/prometheus/client_golang v1.12.2 h1:51L9cDoUHVrXx4zWYlcLQIZ+d+VXHgqnYKkIuq4g/34=
github.com/prometheus/client_golang v1.12.2/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -560,16 +551,15 @@ github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8b
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls=
github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE=
github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA=
github.com/prometheus/common v0.33.0 h1:rHgav/0a6+uYgGdNt3jwz8FNSesO/Hsang3O0T9A5SE=
github.com/prometheus/common v0.33.0/go.mod h1:gB3sOl7P0TvJabZpLY5uQMpUqRCPPCyRLCZYc7JZTNE=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU=
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
@@ -616,7 +606,6 @@ github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9/go.mod h1:mvWM0+15
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@@ -624,9 +613,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@@ -636,6 +624,8 @@ github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJ
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb h1:CU1/+CEeCPvYXgfAyqTJXSQSf6hW3wsWM6Dfz6HkHEQ=
github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb/go.mod h1:XT1Nrb4OxbVFPffbQMbq4PaeEkpRLVzdphh3fjrw7DY=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -655,21 +645,12 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
go.opentelemetry.io/otel/exporters/prometheus v0.33.0 h1:xXhPj7SLKWU5/Zd4Hxmd+X1C4jdmvc0Xy+kvjFx2z60=
go.opentelemetry.io/otel/exporters/prometheus v0.33.0/go.mod h1:ZSmYfKdYWEdSDBB4njLBIwTf4AU2JNsH3n2quVQDebI=
go.opentelemetry.io/otel/metric v0.33.0 h1:xQAyl7uGEYvrLAiV/09iTJlp1pZnQ9Wl793qbVvED1E=
go.opentelemetry.io/otel/metric v0.33.0/go.mod h1:QlTYc+EnYNq/M2mNk1qDDMRLpqCOj2f/r5c7Fd5FYaI=
go.opentelemetry.io/otel/sdk v1.11.1 h1:F7KmQgoHljhUuJyA+9BiU+EkJfyX5nVVF4wyzWZpKxs=
go.opentelemetry.io/otel/sdk v1.11.1/go.mod h1:/l3FE4SupHJ12TduVjUkZtlfFqDCQJlOlithYrdktys=
go.opentelemetry.io/otel/sdk/metric v0.33.0 h1:oTqyWfksgKoJmbrs2q7O7ahkJzt+Ipekihf8vhpa9qo=
go.opentelemetry.io/otel/sdk/metric v0.33.0/go.mod h1:xdypMeA21JBOvjjzDUtD0kzIcHO/SPez+a8HOzJPGp0=
go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ=
go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/zap v1.17.0 h1:MTjgFu6ZLKvY6Pvaqk97GlxNBuMpV4Hy/3P6tRGlI2U=
go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@@ -684,7 +665,7 @@ golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 h1:NUzdAbFtCJSXU20AOXgeqaUwg8Ypg4MPYmL+d+rsB5c=
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -697,8 +678,6 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf h1:oXVg4h2qJDd9htKxb5SCpFBHLipW6hXmL3qpUixS2jw=
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf/go.mod h1:yh0Ynu2b5ZUe3MQfp2nM0ecK7wsgouWTDN0FNeJuIys=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 h1:6WW6V3x1P/jokJBpRQYUJnMHRP6isStQwCozxnU7XQw=
@@ -796,10 +775,8 @@ golang.org/x/net v0.0.0-20211208012354-db4efeb81f4b/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw=
golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220513224357-95641704303c h1:nF9mHSvoKBLkQNQhJZNsc66z2UzAMUbLGjC95CF3pU0=
golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -828,9 +805,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f h1:Ax0t5p6N38Ga0dThY21weqDEyz2oklo4IvDkpigvkD8=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -933,10 +909,8 @@ golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211214234402-4825e8c3871d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 h1:h+EGohizhe9XlX18rfpa8k8RAc5XyaeamM+0VHRd4lc=
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 h1:wEZYwx+kK+KlZ0hpvP2Ls1Xr4+RWnlzGFwPP0aiDjIU=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM=
@@ -1182,8 +1156,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -1214,9 +1188,8 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

185
iface/bind.go Normal file
View File

@@ -0,0 +1,185 @@
package iface
import (
"errors"
"fmt"
"github.com/pion/stun"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"net"
"net/netip"
"sync"
"syscall"
)
type ICEBind struct {
sharedConn net.PacketConn
udpMux *UniversalUDPMuxDefault
iceHostMux *UDPMuxDefault
mu sync.Mutex // protects following fields
}
func (b *ICEBind) GetICEMux() (UniversalUDPMux, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.udpMux, nil
}
func (b *ICEBind) GetICEHostMux() (UDPMux, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.iceHostMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.iceHostMux, nil
}
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.sharedConn != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
port := int(uport)
ipv4Conn, port, err := listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.sharedConn = ipv4Conn
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn})
portAddr1, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
if err != nil {
return nil, 0, err
}
log.Infof("opened ICEBind on %s", ipv4Conn.LocalAddr().String())
return []conn.ReceiveFunc{
b.makeReceiveIPv4(b.sharedConn),
},
portAddr1.Port(), nil
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: append([]byte{}, raw...),
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
return func(buff []byte) (int, conn.Endpoint, error) {
n, endpoint, err := c.ReadFrom(buff)
if err != nil {
return 0, nil, err
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
if !stun.IsMessage(buff[:20]) {
// WireGuard traffic
return n, (*conn.StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), nil
}
msg, err := parseSTUNMessage(buff[:n])
if err != nil {
return 0, nil, err
}
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
if err != nil {
return 0, nil, err
}
if err != nil {
log.Warnf("failed to handle packet")
}
// discard packets because they are STUN related
return 0, nil, nil //todo proper return
}
}
func (b *ICEBind) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
var err1, err2 error
if b.sharedConn != nil {
c := b.sharedConn
b.sharedConn = nil
err1 = c.Close()
}
if b.udpMux != nil {
m := b.udpMux
b.udpMux = nil
err2 = m.Close()
}
if err1 != nil {
return err1
}
return err2
}
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
func (b *ICEBind) SetMark(mark uint32) error {
return nil
}
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
nend, ok := endpoint.(*conn.StdNetEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
_, err := b.sharedConn.WriteTo(buff, (*net.UDPAddr)(nend))
return err
}
// ParseEndpoint creates a new endpoint from a string.
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
e, err := netip.ParseAddrPort(s)
return (*conn.StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), err
}

View File

@@ -55,7 +55,6 @@ func (w *WGIface) Configure(privateKey string, port int) error {
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
err = w.configureDevice(config)

View File

@@ -2,6 +2,10 @@ package iface
import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"net"
"os"
"runtime"
@@ -21,6 +25,7 @@ type WGIface struct {
Address WGAddress
Interface NetInterface
mu sync.Mutex
Bind *ICEBind
}
// WGAddress Wireguard parsed address
@@ -91,3 +96,49 @@ func (w *WGIface) Close() error {
return nil
}
func (w *WGIface) CreateNew(bind conn.Bind) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.createWithUserspaceNew(bind)
}
func (w *WGIface) createWithUserspaceNew(bind conn.Bind) error {
tunIface, err := tun.CreateTUN(w.Name, w.MTU)
if err != nil {
return err
}
w.Interface = tunIface
// We need to create a wireguard-go device and listen to configuration requests
tunDevice := device.NewDevice(tunIface, bind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
err = tunDevice.Up()
if err != nil {
return err
}
uapi, err := getUAPI(w.Name)
if err != nil {
return err
}
go func() {
for {
uapiConn, uapiErr := uapi.Accept()
if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr)
continue
}
go tunDevice.IpcHandle(uapiConn)
}
}()
log.Debugln("UAPI listener started")
err = w.assignAddr()
if err != nil {
return err
}
return nil
}

View File

@@ -34,7 +34,7 @@ func (w *WGIface) assignAddr() error {
return nil
}
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
// WireguardModExists check if we can load wireguard mod (linux only)
func WireguardModExists() bool {
return false
}

View File

@@ -1,32 +1,45 @@
package iface
import (
"fmt"
"errors"
"math"
"os"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"os"
)
type NativeLink struct {
Link *netlink.Link
}
// WireguardModExists check if we can load wireguard mod (linux only)
func WireguardModExists() bool {
link := newWGLink("mustnotexist")
// We willingly try to create a device with an invalid
// MTU here as the validation of the MTU will be performed after
// the validation of the link kind and hence allows us to check
// for the existance of the wireguard module without actually
// creating a link.
//
// As a side-effect, this will also let the kernel lazy-load
// the wireguard module.
link.attrs.MTU = math.MaxInt
err := netlink.LinkAdd(link)
return errors.Is(err, syscall.EINVAL)
}
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
if WireguardModuleIsLoaded() {
log.Info("using kernel WireGuard")
return w.createWithKernel()
} else {
if !tunModuleIsLoaded() {
return fmt.Errorf("couldn't check or load tun module")
}
log.Info("using userspace WireGuard")
return w.createWithUserspace()
}
return w.createWithUserspace()
}
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module.

View File

@@ -89,6 +89,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
return addrs, nil
}
//
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
@@ -368,8 +369,8 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// todo: investigate why in some tests execution we need 30s
timeout := 30 * time.Second
timeout := 10 * time.Second
timeoutChannel := time.After(timeout)
for {
select {

View File

@@ -4,6 +4,7 @@ import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"net"
@@ -58,7 +59,12 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
return w.assignAddr(luid)
}
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
// WireguardModExists check if we can load wireguard mod (linux only)
func WireguardModExists() bool {
return false
}
// getUAPI returns a Listener
func getUAPI(iface string) (net.Listener, error) {
return ipc.UAPIListen(iface)
}

View File

@@ -1,350 +0,0 @@
// Package iface provides wireguard network interface creation and management
package iface
import (
"bufio"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"io/fs"
"io/ioutil"
"math"
"os"
"path/filepath"
"strings"
"syscall"
)
// Holds logic to check existence of kernel modules used by wireguard interfaces
// Copied from https://github.com/paultag/go-modprobe and
// https://github.com/pmorjan/kmod
type status int
const (
defaultModuleDir = "/lib/modules"
unknown status = iota
unloaded
unloading
loading
live
inuse
)
type module struct {
name string
path string
}
var (
// ErrModuleNotFound is the error resulting if a module can't be found.
ErrModuleNotFound = errors.New("module not found")
moduleLibDir = defaultModuleDir
// get the root directory for the kernel modules. If this line panics,
// it's because getModuleRoot has failed to get the uname of the running
// kernel (likely a non-POSIX system, but maybe a broken kernel?)
moduleRoot = getModuleRoot()
)
// Get the module root (/lib/modules/$(uname -r)/)
func getModuleRoot() string {
uname := unix.Utsname{}
if err := unix.Uname(&uname); err != nil {
panic(err)
}
i := 0
for ; uname.Release[i] != 0; i++ {
}
return filepath.Join(moduleLibDir, string(uname.Release[:i]))
}
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
func tunModuleIsLoaded() bool {
_, err := os.Stat("/dev/net/tun")
if err == nil {
return true
}
log.Infof("couldn't access device /dev/net/tun, go error %v, "+
"will attempt to load tun module, if running on container add flag --cap-add=NET_ADMIN", err)
tunLoaded, err := tryToLoadModule("tun")
if err != nil {
log.Errorf("unable to find or load tun module, got error: %v", err)
}
return tunLoaded
}
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
if canCreateFakeWireguardInterface() {
return true
}
loaded, err := tryToLoadModule("wireguard")
if err != nil {
log.Info(err)
return false
}
return loaded
}
func canCreateFakeWireguardInterface() bool {
link := newWGLink("mustnotexist")
// We willingly try to create a device with an invalid
// MTU here as the validation of the MTU will be performed after
// the validation of the link kind and hence allows us to check
// for the existance of the wireguard module without actually
// creating a link.
//
// As a side-effect, this will also let the kernel lazy-load
// the wireguard module.
link.attrs.MTU = math.MaxInt
err := netlink.LinkAdd(link)
return errors.Is(err, syscall.EINVAL)
}
func tryToLoadModule(moduleName string) (bool, error) {
if isModuleEnabled(moduleName) {
return true, nil
}
modulePath, err := getModulePath(moduleName)
if err != nil {
return false, fmt.Errorf("couldn't find module path for %s, error: %v", moduleName, err)
}
if modulePath == "" {
return false, nil
}
log.Infof("trying to load %s module", moduleName)
err = loadModuleWithDependencies(moduleName, modulePath)
if err != nil {
return false, fmt.Errorf("couldn't load %s module, error: %v", moduleName, err)
}
return true, nil
}
func isModuleEnabled(name string) bool {
builtin, builtinErr := isBuiltinModule(name)
state, statusErr := moduleStatus(name)
return (builtinErr == nil && builtin) || (statusErr == nil && state >= loading)
}
func getModulePath(name string) (string, error) {
var foundPath string
skipRemainingDirs := false
err := filepath.WalkDir(
moduleRoot,
func(path string, info fs.DirEntry, err error) error {
if skipRemainingDirs {
return fs.SkipDir
}
if err != nil {
// skip broken files
return nil
}
if !info.Type().IsRegular() {
return nil
}
nameFromPath := pathToName(path)
if nameFromPath == name {
foundPath = path
skipRemainingDirs = true
}
return nil
})
if err != nil {
return "", err
}
return foundPath, nil
}
func pathToName(s string) string {
s = filepath.Base(s)
for ext := filepath.Ext(s); ext != ""; ext = filepath.Ext(s) {
s = strings.TrimSuffix(s, ext)
}
return cleanName(s)
}
func cleanName(s string) string {
return strings.ReplaceAll(strings.TrimSpace(s), "-", "_")
}
func isBuiltinModule(name string) (bool, error) {
f, err := os.Open(filepath.Join(moduleRoot, "/modules.builtin"))
if err != nil {
return false, err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing modules.builtin file, %v", err)
}
}()
var found bool
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if pathToName(line) == name {
found = true
break
}
}
if err := scanner.Err(); err != nil {
return false, err
}
return found, nil
}
// /proc/modules
//
// name | memory size | reference count | references | state: <Live|Loading|Unloading>
// macvlan 28672 1 macvtap, Live 0x0000000000000000
func moduleStatus(name string) (status, error) {
state := unknown
f, err := os.Open("/proc/modules")
if err != nil {
return state, err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing /proc/modules file, %v", err)
}
}()
state = unloaded
scanner := bufio.NewScanner(f)
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if fields[0] == name {
if fields[2] != "0" {
state = inuse
break
}
switch fields[4] {
case "Live":
state = live
case "Loading":
state = loading
case "Unloading":
state = unloading
}
break
}
}
if err := scanner.Err(); err != nil {
return state, err
}
return state, nil
}
func loadModuleWithDependencies(name, path string) error {
deps, err := getModuleDependencies(name)
if err != nil {
return fmt.Errorf("couldn't load list of module %s dependecies", name)
}
for _, dep := range deps {
err = loadModule(dep.name, dep.path)
if err != nil {
return fmt.Errorf("couldn't load dependecy module %s for %s", dep.name, name)
}
}
return loadModule(name, path)
}
func loadModule(name, path string) error {
state, err := moduleStatus(name)
if err != nil {
return err
}
if state >= loading {
return nil
}
f, err := os.Open(path)
if err != nil {
return err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing %s file, %v", path, err)
}
}()
// first try finit_module(2), then init_module(2)
err = unix.FinitModule(int(f.Fd()), "", 0)
if errors.Is(err, unix.ENOSYS) {
buf, err := ioutil.ReadAll(f)
if err != nil {
return err
}
return unix.InitModule(buf, "")
}
return err
}
// getModuleDependencies returns a module dependencies
func getModuleDependencies(name string) ([]module, error) {
f, err := os.Open(filepath.Join(moduleRoot, "/modules.dep"))
if err != nil {
return nil, err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing modules.dep file, %v", err)
}
}()
var deps []string
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if pathToName(strings.TrimSuffix(fields[0], ":")) == name {
deps = fields
break
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
if len(deps) == 0 {
return nil, ErrModuleNotFound
}
deps[0] = strings.TrimSuffix(deps[0], ":")
var modules []module
for _, v := range deps {
if pathToName(v) != name {
modules = append(modules, module{
name: pathToName(v),
path: filepath.Join(moduleRoot, v),
})
}
}
return modules, nil
}

View File

@@ -1,221 +0,0 @@
package iface
import (
"bufio"
"bytes"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
)
func TestGetModuleDependencies(t *testing.T) {
testCases := []struct {
name string
module string
expected []module
}{
{
name: "Get Single Dependency",
module: "bar",
expected: []module{
{name: "foo", path: "kernel/a/foo.ko"},
},
},
{
name: "Get Multiple Dependencies",
module: "baz",
expected: []module{
{name: "foo", path: "kernel/a/foo.ko"},
{name: "bar", path: "kernel/a/bar.ko"},
},
},
{
name: "Get No Dependencies",
module: "foo",
expected: []module{},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
defer resetGlobals()
_, _ = createFiles(t)
modules, err := getModuleDependencies(testCase.module)
require.NoError(t, err)
expected := testCase.expected
for i := range expected {
expected[i].path = moduleRoot + "/" + expected[i].path
}
require.ElementsMatchf(t, modules, expected, "returned modules should match")
})
}
}
func TestIsBuiltinModule(t *testing.T) {
testCases := []struct {
name string
module string
expected bool
}{
{
name: "Built In Should Return True",
module: "foo_bi",
expected: true,
},
{
name: "Not Built In Should Return False",
module: "not_built_in",
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
defer resetGlobals()
_, _ = createFiles(t)
isBuiltIn, err := isBuiltinModule(testCase.module)
require.NoError(t, err)
require.Equal(t, testCase.expected, isBuiltIn)
})
}
}
func TestModuleStatus(t *testing.T) {
random, err := getRandomLoadedModule(t)
if err != nil {
t.Fatal("should be able to get random module")
}
testCases := []struct {
name string
module string
shouldBeLoaded bool
}{
{
name: "Should Return Module Loading Or Greater Status",
module: random,
shouldBeLoaded: true,
},
{
name: "Should Return Module Unloaded Or Lower Status",
module: "not_loaded_module",
shouldBeLoaded: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
defer resetGlobals()
_, _ = createFiles(t)
state, err := moduleStatus(testCase.module)
require.NoError(t, err)
if testCase.shouldBeLoaded {
require.GreaterOrEqual(t, loading, state, "moduleStatus for %s should return state loading", testCase.module)
} else {
require.Less(t, state, loading, "module should return state unloading or lower")
}
})
}
}
func resetGlobals() {
moduleLibDir = defaultModuleDir
moduleRoot = getModuleRoot()
}
func createFiles(t *testing.T) (string, []module) {
writeFile := func(path, text string) {
if err := ioutil.WriteFile(path, []byte(text), 0644); err != nil {
t.Fatal(err)
}
}
var u unix.Utsname
if err := unix.Uname(&u); err != nil {
t.Fatal(err)
}
moduleLibDir = t.TempDir()
moduleRoot = getModuleRoot()
if err := os.Mkdir(moduleRoot, 0755); err != nil {
t.Fatal(err)
}
text := "kernel/a/foo.ko:\n"
text += "kernel/a/bar.ko: kernel/a/foo.ko\n"
text += "kernel/a/baz.ko: kernel/a/bar.ko kernel/a/foo.ko\n"
writeFile(filepath.Join(moduleRoot, "/modules.dep"), text)
text = "kernel/a/foo_bi.ko\n"
text += "kernel/a/bar-bi.ko.gz\n"
writeFile(filepath.Join(moduleRoot, "/modules.builtin"), text)
modules := []module{
{name: "foo", path: "kernel/a/foo.ko"},
{name: "bar", path: "kernel/a/bar.ko"},
{name: "baz", path: "kernel/a/baz.ko"},
}
return moduleLibDir, modules
}
func getRandomLoadedModule(t *testing.T) (string, error) {
f, err := os.Open("/proc/modules")
if err != nil {
return "", err
}
defer func() {
err := f.Close()
if err != nil {
t.Logf("failed closing /proc/modules file, %v", err)
}
}()
lines, err := lineCounter(f)
if err != nil {
return "", err
}
counter := 1
midLine := lines / 2
modName := ""
scanner := bufio.NewScanner(f)
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if counter == midLine {
if fields[4] == "Unloading" {
continue
}
modName = fields[0]
break
}
counter++
}
if scanner.Err() != nil {
return "", scanner.Err()
}
return modName, nil
}
func lineCounter(r io.Reader) (int, error) {
buf := make([]byte, 32*1024)
count := 0
lineSep := []byte{'\n'}
for {
c, err := r.Read(buf)
count += bytes.Count(buf[:c], lineSep)
switch {
case err == io.EOF:
return count, nil
case err != nil:
return count, err
}
}
}

288
iface/udp_mux.go Normal file
View File

@@ -0,0 +1,288 @@
package iface
import (
"fmt"
log "github.com/sirupsen/logrus"
"io"
"net"
"strings"
"sync"
"github.com/pion/logging"
"github.com/pion/stun"
)
const receiveMTU = 8192
// UDPMux allows multiple connections to go over a single UDP port
type UDPMux interface {
io.Closer
GetConn(ufrag string) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
}
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
closedChan chan struct{}
closeOnce sync.Once
// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
mu sync.Mutex
}
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
conns: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
}
}
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
for _, conn := range storedConns {
destinationConnList = append(destinationConnList, conn)
}
}
m.addressMapMu.Unlock()
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
if destinationConn, ok := m.conns[ufrag]; ok {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()
log.Debugf("ICE: getting muxed connection for %s", ufrag)
if m.IsClosed() {
return nil, io.ErrClosedPipe
}
if c, ok := m.conns[ufrag]; ok {
return c, nil
}
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.removeConn(ufrag)
}()
m.conns[ufrag] = c
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Lock()
removedConns := make([]*udpMuxedConn, 0)
for key := range m.conns {
if key != ufrag {
continue
}
c := m.conns[key]
delete(m.conns, key)
if c != nil {
removedConns = append(removedConns, c)
}
}
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
if connList, ok := m.addressMap[addr]; ok {
var newList []*udpMuxedConn
for _, conn := range connList {
if conn.params.Key != ufrag {
newList = append(newList, conn)
}
}
m.addressMap[addr] = newList
}
}
}
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
default:
return false
}
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
defer m.mu.Unlock()
for _, c := range m.conns {
_ = c.Close()
}
m.conns = make(map[string]*udpMuxedConn)
close(m.closedChan)
})
return err
}
func (m *UDPMuxDefault) removeConn(key string) {
m.mu.Lock()
c := m.conns[key]
delete(m.conns, key)
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
if c == nil {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
addresses := c.getAddresses()
for _, addr := range addresses {
if connList, ok := m.addressMap[addr]; ok {
var newList []*udpMuxedConn
for _, conn := range connList {
if conn.params.Key != key {
newList = append(newList, conn)
}
}
m.addressMap[addr] = newList
}
}
}
func (m *UDPMuxDefault) writeTo(buf []byte, raddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, raddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
log.Debugf("ICE: created muxed connection %s for %s", c.LocalAddr().String(), key)
return c
}
type bufferHolder struct {
buffer []byte
}
func newBufferHolder(size int) *bufferHolder {
return &bufferHolder{
buffer: make([]byte, size),
}
}

235
iface/udp_mux_universal.go Normal file
View File

@@ -0,0 +1,235 @@
package iface
import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"net"
"time"
"github.com/pion/logging"
"github.com/pion/stun"
)
// UniversalUDPMux allows multiple connections to go over a single UDP port for
// host, server reflexive and relayed candidates.
// Actual connection muxing is happening in the UDPMux.
type UniversalUDPMux interface {
UDPMux
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
GetConnForURL(ufrag string, url string) (net.PacketConn, error)
}
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
// It the passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
// stun.XORMappedAddress indexed by the STUN server addr
xorMappedMap map[string]*xorMapped
}
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
type UniversalUDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25
}
m := &UniversalUDPMuxDefault{
params: params,
xorMappedMap: make(map[string]*xorMapped),
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
}
// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr.
// Not implemented yet.
func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) {
return nil, errors.New("not implemented yet")
}
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url))
}
func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// message about this err will be logged in the UDPMux
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {
log.Debugf("%w: %v", errors.New("failed to get XOR-MAPPED-ADDRESS response"), err)
return nil
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
m.mu.Lock()
defer m.mu.Unlock()
// check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
_, ok := m.xorMappedMap[stunAddr]
_, err := msg.Get(stun.AttrXORMappedAddress)
return err == nil && ok
}
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
// and set the mapped address for the server
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
m.mu.Lock()
defer m.mu.Unlock()
mappedAddr, ok := m.xorMappedMap[stunAddr.String()]
if !ok {
return errors.New("no address mapping")
}
var addr stun.XORMappedAddress
if err := addr.GetFrom(msg); err != nil {
return err
}
m.xorMappedMap[stunAddr.String()] = mappedAddr
mappedAddr.SetAddr(&addr)
return nil
}
// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server.
// Makes a STUN binding request to discover mapped address otherwise.
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
m.mu.Lock()
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
// if we already have a mapping for this STUN server (address already received)
// and if it is not too old we return it without making a new request to STUN server
if ok {
if mappedAddr.expired() {
mappedAddr.closeWaiters()
delete(m.xorMappedMap, serverAddr.String())
ok = false
} else if mappedAddr.pending() {
ok = false
}
}
m.mu.Unlock()
if ok {
return mappedAddr.addr, nil
}
// otherwise, make a STUN request to discover the address
// or wait for already sent request to complete
waitAddrReceived, err := m.sendStun(serverAddr)
if err != nil {
return nil, errors.New("failed to send STUN packet")
}
// block until response was handled by the connWorker routine and XORMappedAddress was updated
select {
case <-waitAddrReceived:
// when channel closed, addr was obtained
m.mu.Lock()
mappedAddr := *m.xorMappedMap[serverAddr.String()]
m.mu.Unlock()
if mappedAddr.addr == nil {
return nil, errors.New("no address mapping")
}
return mappedAddr.addr, nil
case <-time.After(deadline):
return nil, errors.New("timeout while waiting for XORMappedAddr")
}
}
// sendStun sends a STUN request via UDP conn.
//
// The returned channel is closed when the STUN response has been received.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) {
m.mu.Lock()
defer m.mu.Unlock()
// if record present in the map, we already sent a STUN request,
// just wait when waitAddrReceived will be closed
addrMap, ok := m.xorMappedMap[serverAddr.String()]
if !ok {
addrMap = &xorMapped{
expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL),
waitAddrReceived: make(chan struct{}),
}
m.xorMappedMap[serverAddr.String()] = addrMap
}
req, err := stun.Build(stun.BindingRequest, stun.TransactionID)
if err != nil {
return nil, err
}
if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil {
return nil, err
}
return addrMap.waitAddrReceived, nil
}
type xorMapped struct {
addr *stun.XORMappedAddress
waitAddrReceived chan struct{}
expiresAt time.Time
}
func (a *xorMapped) closeWaiters() {
select {
case <-a.waitAddrReceived:
// notify was close, ok, that means we received duplicate response
// just exit
break
default:
// notify tha twe have a new addr
close(a.waitAddrReceived)
}
}
func (a *xorMapped) pending() bool {
return a.addr == nil
}
func (a *xorMapped) expired() bool {
return a.expiresAt.Before(time.Now())
}
func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) {
a.addr = addr
a.closeWaiters()
}

246
iface/udp_muxed_conn.go Normal file
View File

@@ -0,0 +1,246 @@
package iface
import (
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/transport/packetio"
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
type udpMuxedConn struct {
params *udpMuxedConnParams
// remote addresses that we have sent to on this conn
addresses []string
// channel holding incoming packets
buffer *packetio.Buffer
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
}
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
params: params,
buffer: packetio.NewBuffer(),
closedChan: make(chan struct{}),
}
return p
}
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) {
buf := c.params.AddrPool.Get().(*bufferHolder)
defer c.params.AddrPool.Put(buf)
// read address
total, err := c.buffer.Read(buf.buffer)
if err != nil {
return 0, nil, err
}
dataLen := int(binary.LittleEndian.Uint16(buf.buffer[:2]))
if dataLen > total || dataLen > len(b) {
return 0, nil, io.ErrShortBuffer
}
// read data and then address
offset := 2
copy(b, buf.buffer[offset:offset+dataLen])
offset += dataLen
// read address len & decode address
addrLen := int(binary.LittleEndian.Uint16(buf.buffer[offset : offset+2]))
offset += 2
if raddr, err = decodeUDPAddr(buf.buffer[offset : offset+addrLen]); err != nil {
return 0, nil, err
}
return dataLen, raddr, nil
}
func (c *udpMuxedConn) WriteTo(buf []byte, raddr net.Addr) (n int, err error) {
if c.isClosed() {
return 0, io.ErrClosedPipe
}
// each time we write to a new address, we'll register it with the mux
addr := raddr.String()
if !c.containsAddress(addr) {
c.addAddress(addr)
}
return c.params.Mux.writeTo(buf, raddr)
}
func (c *udpMuxedConn) LocalAddr() net.Addr {
return c.params.LocalAddr
}
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
return c.closedChan
}
func (c *udpMuxedConn) Close() error {
var err error
c.closeOnce.Do(func() {
err = c.buffer.Close()
close(c.closedChan)
})
c.mu.Lock()
defer c.mu.Unlock()
c.addresses = nil
return err
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
default:
return false
}
}
func (c *udpMuxedConn) getAddresses() []string {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}
func (c *udpMuxedConn) addAddress(addr string) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
// map it on mux
c.params.Mux.registerConnForAddress(c, addr)
}
func (c *udpMuxedConn) removeAddress(addr string) {
c.mu.Lock()
defer c.mu.Unlock()
newAddresses := make([]string, 0, len(c.addresses))
for _, a := range c.addresses {
if a != addr {
newAddresses = append(newAddresses, a)
}
}
c.addresses = newAddresses
}
func (c *udpMuxedConn) containsAddress(addr string) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
if addr == a {
return true
}
}
return false
}
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
// write two packets, address and data
buf := c.params.AddrPool.Get().(*bufferHolder)
defer c.params.AddrPool.Put(buf)
// format of buffer | data len | data bytes | addr len | addr bytes |
if len(buf.buffer) < len(data)+maxAddrSize {
return io.ErrShortBuffer
}
// data len
binary.LittleEndian.PutUint16(buf.buffer, uint16(len(data)))
offset := 2
// data
copy(buf.buffer[offset:], data)
offset += len(data)
// write address first, leaving room for its length
n, err := encodeUDPAddr(addr, buf.buffer[offset+2:])
if err != nil {
return nil
}
total := offset + n + 2
// address len
binary.LittleEndian.PutUint16(buf.buffer[offset:], uint16(n))
if _, err := c.buffer.Write(buf.buffer[:total]); err != nil {
return err
}
return nil
}
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
ipdata, err := addr.IP.MarshalText()
if err != nil {
return 0, err
}
total := 2 + len(ipdata) + 2 + len(addr.Zone)
if total > len(buf) {
return 0, io.ErrShortBuffer
}
binary.LittleEndian.PutUint16(buf, uint16(len(ipdata)))
offset := 2
n := copy(buf[offset:], ipdata)
offset += n
binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
offset += 2
copy(buf[offset:], addr.Zone)
return total, nil
}
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
addr := net.UDPAddr{}
offset := 0
ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
offset += 2
// basic bounds checking
if ipLen+offset > len(buf) {
return nil, io.ErrShortBuffer
}
if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
return nil, err
}
offset += ipLen
addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
offset += 2
zone := make([]byte, len(buf[offset:]))
copy(zone, buf[offset:])
addr.Zone = string(zone)
return &addr, nil
}

View File

@@ -10,8 +10,6 @@ NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem"
# Management Certficate key file path.
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem"
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
# Turn credentials
@@ -31,8 +29,6 @@ LETSENCRYPT_VOLUMESUFFIX="letsencrypt"
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
NETBIRD_DISABLE_ANONYMOUS_METRICS=${NETBIRD_DISABLE_ANONYMOUS_METRICS:-false}
# exports
export NETBIRD_DOMAIN
export NETBIRD_AUTH_CLIENT_ID
@@ -49,8 +45,6 @@ export NETBIRD_MGMT_API_CERT_KEY_FILE
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER
export NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT
export NETBIRD_AUTH_REDIRECT_URI
export NETBIRD_AUTH_SILENT_REDIRECT_URI
export TURN_USER
export TURN_PASSWORD
export TURN_MIN_PORT
@@ -59,4 +53,3 @@ export VOLUME_PREFIX
export MGMT_VOLUMESUFFIX
export SIGNAL_VOLUMESUFFIX
export LETSENCRYPT_VOLUMESUFFIX
export NETBIRD_DISABLE_ANONYMOUS_METRICS

View File

@@ -18,8 +18,6 @@ services:
- NGINX_SSL_PORT=443
- LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
volumes:
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
# Signal
@@ -48,7 +46,7 @@ services:
# # port and command for Let's Encrypt validation without dashboard container
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN"]
command: ["--port", "443", "--log-file", "console"]
# Coturn
coturn:
image: coturn/coturn

View File

@@ -12,11 +12,4 @@ NETBIRD_USE_AUTH0="false"
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID=""
# e.g. hello@mydomain.com
NETBIRD_LETSENCRYPT_EMAIL=""
# if your IDP provider doesn't support fragmented URIs, configure custom
# redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain.
# NETBIRD_AUTH_REDIRECT_URI="/peers"
# NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers"
# Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
NETBIRD_LETSENCRYPT_EMAIL=""

View File

@@ -10,5 +10,4 @@ NETBIRD_AUTH_CLIENT_ID=$CI_NETBIRD_AUTH_CLIENT_ID
NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0
NETBIRD_AUTH_AUDIENCE=$CI_NETBIRD_AUTH_AUDIENCE
# e.g. hello@mydomain.com
NETBIRD_LETSENCRYPT_EMAIL=""
NETBIRD_AUTH_REDIRECT_URI="/peers"
NETBIRD_LETSENCRYPT_EMAIL=""

View File

@@ -55,12 +55,12 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
}
peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "")
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
if err != nil {
t.Fatal(err)
}

View File

@@ -109,9 +109,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
cancel, stream, err := c.connectToStream(*serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -119,6 +117,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
}
return err
}
defer cancel()
log.Infof("connected to the Management Service stream")
@@ -147,7 +146,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (context.CancelFunc, proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}
myPrivateKey := c.key
@@ -156,14 +155,16 @@ func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.K
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
if err != nil {
log.Errorf("failed encrypting message: %s", err)
return nil, err
return nil, nil, err
}
ctx, cancel := context.WithCancel(c.ctx)
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
sync, err := c.realClient.Sync(ctx, syncReq)
if err != nil {
return nil, err
cancel()
return nil, nil, err
}
return sync, nil
return cancel, sync, nil
}
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {

View File

@@ -1,16 +1,12 @@
package cmd
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
"github.com/google/uuid"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
@@ -42,13 +38,11 @@ import (
const ManagementLegacyPort = 33073
var (
mgmtPort int
mgmtMetricsPort int
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
config *server.Config
mgmtPort int
mgmtLetsencryptDomain string
certFile string
certKey string
config *server.Config
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
@@ -116,27 +110,15 @@ var (
}
peersUpdateManager := server.NewPeersUpdateManager()
appMetrics, err := telemetry.NewDefaultAppMetrics(cmd.Context())
if err != nil {
return err
}
err = appMetrics.Expose(mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
var idpManager idp.Manager
if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics)
idpManager, err = idp.NewManager(*config.IdpManagerConfig)
if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
}
}
if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain)
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
}
@@ -166,34 +148,19 @@ var (
tlsEnabled = true
}
httpAPIHandler, err := httpapi.APIHandler(accountManager, config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, appMetrics)
httpAPIHandler, err := httpapi.APIHandler(accountManager,
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(store)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
return err
}
fmt.Println("metrics ", disableMetrics)
if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager)
go metricsWorker.Run()
}
var compatListener net.Listener
if mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
@@ -240,7 +207,6 @@ var (
SetupCloseHandler()
<-stopCh
_ = appMetrics.Close()
_ = listener.Close()
if certManager != nil {
_ = certManager.Listener().Close()
@@ -262,20 +228,6 @@ func notifyStop(msg string) {
}
}
func getInstallationID(store server.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(installationID)
if err != nil {
return "", err
}
return installationID, nil
}
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
@@ -443,7 +395,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
return nil, err
}
// NewDefaultAppMetrics the credentials and return it
// Create the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,

View File

@@ -13,23 +13,20 @@ const (
)
var (
defaultMgmtConfigDir string
defaultMgmtDataDir string
defaultMgmtConfig string
defaultSingleAccModeDomain string
defaultLogDir string
defaultLogFile string
oldDefaultMgmtConfigDir string
oldDefaultMgmtDataDir string
oldDefaultMgmtConfig string
oldDefaultLogDir string
oldDefaultLogFile string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
disableSingleAccMode bool
defaultMgmtConfigDir string
defaultMgmtDataDir string
defaultMgmtConfig string
defaultLogDir string
defaultLogFile string
oldDefaultMgmtConfigDir string
oldDefaultMgmtDataDir string
oldDefaultMgmtConfig string
oldDefaultLogDir string
oldDefaultLogFile string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
rootCmd = &cobra.Command{
Use: "netbird-mgmt",
@@ -50,7 +47,6 @@ func init() {
stopCh = make(chan int)
defaultMgmtDataDir = "/var/lib/netbird/"
defaultSingleAccModeDomain = "netbird.selfhosted"
defaultMgmtConfigDir = "/etc/netbird"
defaultLogDir = "/var/log/netbird"
@@ -65,15 +61,11 @@ func init() {
oldDefaultLogFile = oldDefaultLogDir + "/management.log"
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")
mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@@ -1,17 +1,4 @@
#!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./management.proto --go_out=../ --go-grpc_out=../
cd "$old_pwd"
protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=.

View File

@@ -3,9 +3,8 @@ package server
import (
"context"
"fmt"
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/eko/gocache/v2/cache"
cacheStore "github.com/eko/gocache/v2/store"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/route"
@@ -16,7 +15,6 @@ import (
"google.golang.org/grpc/status"
"math/rand"
"reflect"
"regexp"
"strings"
"sync"
"time"
@@ -30,33 +28,23 @@ const (
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
)
func cacheEntryExpiration() time.Duration {
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
return time.Duration(r) * time.Millisecond
}
type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
CreateSetupKey(
AddSetupKey(
accountId string,
keyName string,
keyType SetupKeyType,
expiresIn time.Duration,
autoGroups []string,
) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
CreateUser(accountID string, key *UserInfo) (*UserInfo, error)
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
GetPeer(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error)
MarkPeerConnected(peerKey string, connected bool) error
RenamePeer(accountId string, peerKey string, newName string) (*Peer, error)
DeletePeer(accountId string, peerKey string) (*Peer, error)
@@ -67,7 +55,7 @@ type AccountManager interface {
AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error)
UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error
UpdatePeerSSHKey(peerKey string, sshKey string) error
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetUsersFromAccount(accountId string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountId string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
@@ -76,44 +64,27 @@ type AccountManager interface {
GroupAddPeer(accountId, groupID, peerKey string) error
GroupDeletePeer(accountId, groupID, peerKey string) error
GroupListPeers(accountId, groupID string) ([]*Peer, error)
GetRule(accountID, ruleID, userID string) (*Rule, error)
GetRule(accountId, ruleID string) (*Rule, error)
SaveRule(accountID string, rule *Rule) error
UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error)
DeleteRule(accountId, ruleID string) error
ListRules(accountID, userID string) ([]*Rule, error)
GetRoute(accountID, routeID, userID string) (*route.Route, error)
ListRules(accountId string) ([]*Rule, error)
GetRoute(accountID, routeID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
SaveRoute(accountID string, route *route.Route) error
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID string) error
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
ListRoutes(accountID string) ([]*route.Route, error)
}
type DefaultAccountManager struct {
Store Store
// mux to synchronise account operations (e.g. generating Peer IP address inside the Network)
mux sync.Mutex
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
// mutex to synchronise account operations (e.g. generating Peer IP address inside the Network)
mux sync.Mutex
peersUpdateManager *PeersUpdateManager
idpManager idp.Manager
cacheManager cache.CacheInterface[[]*idp.UserData]
cacheManager cache.CacheInterface
ctx context.Context
// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
// This value will be set to false if management service has more than one account.
singleAccountMode bool
// singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string
}
// Account represents a unique account of the system
@@ -131,26 +102,13 @@ type Account struct {
Groups map[string]*Group
Rules map[string]*Rule
Routes map[string]*route.Route
NameServerGroups map[string]*nbdns.NameServerGroup
}
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
}
// FindUser looks for a given user in the Account or returns error if user wasn't found.
func (a *Account) FindUser(userID string) (*User, error) {
user := a.Users[userID]
if user == nil {
return nil, Errorf(UserNotFound, "user %s not found", userID)
}
return user, nil
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
}
func (a *Account) Copy() *Account {
@@ -179,27 +137,15 @@ func (a *Account) Copy() *Account {
rules[id] = rule.Copy()
}
routes := map[string]*route.Route{}
for id, route := range a.Routes {
routes[id] = route.Copy()
}
nsGroups := map[string]*nbdns.NameServerGroup{}
for id, nsGroup := range a.NameServerGroups {
nsGroups[id] = nsGroup.Copy()
}
return &Account{
Id: a.Id,
CreatedBy: a.CreatedBy,
SetupKeys: setupKeys,
Network: a.Network.Copy(),
Peers: peers,
Users: users,
Groups: groups,
Rules: rules,
Routes: routes,
NameServerGroups: nsGroups,
Id: a.Id,
CreatedBy: a.CreatedBy,
SetupKeys: setupKeys,
Network: a.Network.Copy(),
Peers: peers,
Users: users,
Groups: groups,
Rules: rules,
}
}
@@ -213,31 +159,21 @@ func (a *Account) GetGroupAll() (*Group, error) {
}
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string) (*DefaultAccountManager, error) {
func BuildManager(
store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
mux: sync.Mutex{},
peersUpdateManager: peersUpdateManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
}
allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
if am.singleAccountMode {
am.singleAccountModeDomain = singleAccountModeDomain
log.Infof("single account mode enabled, accounts number %d", len(allAccounts))
} else {
log.Infof("single account mode disabled, accounts number %d", len(allAccounts))
}
// if account doesn't have a default group
// if account has not default group
// we create 'all' group and add all peers into it
// also we create default rule with source as destination
for _, account := range allAccounts {
for _, account := range store.GetAllAccounts() {
_, err := account.GetGroupAll()
if err != nil {
addAllGroup(account)
@@ -247,10 +183,10 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
}
}
goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
goCacheStore := cacheStore.NewGoCache(goCacheClient)
gocacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
gocacheStore := cacheStore.NewGoCache(gocacheClient, nil)
am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](goCacheStore))
am.cacheManager = cache.NewLoadable(am.loadFromCache, cache.New(gocacheStore))
if !isNil(am.idpManager) {
go func() {
@@ -295,7 +231,11 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
}
for accountID, users := range userData {
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
rand.Seed(time.Now().UnixNano())
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
expiration := time.Duration(r) * time.Millisecond
err = am.cacheManager.Set(am.ctx, accountID, users, &cacheStore.Options{Expiration: expiration})
if err != nil {
return err
}
@@ -304,6 +244,93 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return nil
}
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
func (am *DefaultAccountManager) AddSetupKey(
accountId string,
keyName string,
keyType SetupKeyType,
expiresIn time.Duration,
) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 {
keyDuration = expiresIn
}
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := GenerateSetupKey(keyName, keyType, keyDuration)
account.SetupKeys[setupKey.Key] = setupKey
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return setupKey, nil
}
// RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore
func (am *DefaultAccountManager) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := getAccountSetupKeyById(account, keyId)
if setupKey == nil {
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
}
keyCopy := setupKey.Copy()
keyCopy.Revoked = true
account.SetupKeys[keyCopy.Key] = keyCopy
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return keyCopy, nil
}
// RenameSetupKey renames existing setup key of the specified account.
func (am *DefaultAccountManager) RenameSetupKey(
accountId string,
keyId string,
newName string,
) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := getAccountSetupKeyById(account, keyId)
if setupKey == nil {
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
}
keyCopy := setupKey.Copy()
keyCopy.Name = newName
account.SetupKeys[keyCopy.Key] = keyCopy
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return keyCopy, nil
}
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) {
am.mux.Lock()
@@ -329,7 +356,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
}
err = am.addAccountIDToIDPAppMeta(userId, account)
err = am.updateIDPMetadata(userId, account.Id)
if err != nil {
return nil, err
}
@@ -343,28 +370,10 @@ func isNil(i idp.Manager) bool {
return i == nil || reflect.ValueOf(i).IsNil()
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account *Account) error {
// updateIDPMetadata update user's app metadata in idp manager
func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) error {
if !isNil(am.idpManager) {
// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(userID, account)
if err != nil {
return err
}
if user != nil && user.AppMetadata.WTAccountID == account.Id {
// it was already set, so we skip the unnecessary update
log.Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
account.Id, userID)
return nil
}
err = am.idpManager.UpdateUserAppMetadata(userID, idp.AppMetadata{WTAccountID: account.Id})
if err != nil {
return err
}
err := am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: accountID})
if err != nil {
return status.Errorf(
codes.Internal,
@@ -372,113 +381,39 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account
err,
)
}
// refresh cache to reflect the update
_, err = am.refreshCache(account.Id)
if err != nil {
return err
}
}
return nil
}
func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) {
log.Debugf("account %s not found in cache, reloading", accountID)
func mergeLocalAndQueryUser(queried idp.UserData, local User) *UserInfo {
return &UserInfo{
ID: local.Id,
Email: queried.Email,
Name: queried.Name,
Role: string(local.Role),
}
}
func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) {
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
}
func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) {
data, err := am.getAccountFromCache(accountID, false)
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, accountID string) ([]*idp.UserData, error) {
data, err := am.cacheManager.Get(am.ctx, accountID)
if err != nil {
return nil, err
}
for _, datum := range data {
if datum.Email == email {
return datum, nil
}
}
return nil, nil
}
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]struct{}, len(account.Users))
for _, user := range account.Users {
users[user.Id] = struct{}{}
}
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id)
if err != nil {
return nil, err
}
for _, datum := range userData {
if datum.ID == userID {
return datum, nil
}
}
return nil, nil
}
func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) {
return am.getAccountFromCache(accountID, true)
}
// getAccountFromCache returns user data for a given account ensuring that cache load happens only once
func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceReload bool) ([]*idp.UserData, error) {
am.cacheMux.Lock()
loadingChan := am.cacheLoading[accountID]
if loadingChan == nil {
loadingChan = make(chan struct{})
am.cacheLoading[accountID] = loadingChan
am.cacheMux.Unlock()
defer func() {
am.cacheMux.Lock()
delete(am.cacheLoading, accountID)
close(loadingChan)
am.cacheMux.Unlock()
}()
if forceReload {
err := am.cacheManager.Delete(am.ctx, accountID)
if err != nil {
return nil, err
}
}
return am.cacheManager.Get(am.ctx, accountID)
}
am.cacheMux.Unlock()
log.Debugf("one request to get account %s is already running", accountID)
select {
case <-loadingChan:
// channel has been closed meaning cache was loaded => simply return from cache
return am.cacheManager.Get(am.ctx, accountID)
case <-time.After(5 * time.Second):
return nil, fmt.Errorf("timeout while waiting for account %s cache to reload", accountID)
}
}
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
data, err := am.getAccountFromCache(accountID, false)
if err != nil {
return nil, err
}
userData := data.([]*idp.UserData)
userDataMap := make(map[string]struct{})
for _, datum := range data {
for _, datum := range userData {
userDataMap[datum.ID] = struct{}{}
}
// check whether we need to reload the cache
// the accountUsers ID list is the source of truth and all the users should be in the cache
reload := len(accountUsers) != len(data)
reload := len(accountUsers) != len(userData)
for user := range accountUsers {
if _, ok := userDataMap[user]; !ok {
reload = true
@@ -487,13 +422,59 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
if reload {
// reload cache once avoiding loops
data, err = am.refreshCache(accountID)
err := am.cacheManager.Delete(am.ctx, accountID)
if err != nil {
return nil, err
}
data, err = am.cacheManager.Get(am.ctx, accountID)
if err != nil {
return nil, err
}
userData = data.([]*idp.UserData)
}
return userData, err
}
// GetUsersFromAccount performs a batched request for users from IDP by account id
func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) {
account, err := am.GetAccountById(accountID)
if err != nil {
return nil, err
}
queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) {
queriedUsers, err = am.lookupCache(account.Users, accountID)
if err != nil {
return nil, err
}
}
return data, err
userInfo := make([]*UserInfo, 0)
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 {
for _, user := range account.Users {
userInfo = append(userInfo, &UserInfo{
ID: user.Id,
Email: "",
Name: "",
Role: string(user.Role),
})
}
return userInfo, nil
}
for _, queriedUser := range queriedUsers {
if localUser, contains := account.Users[queriedUser.ID]; contains {
userInfo = append(userInfo, mergeLocalAndQueryUser(*queriedUser, *localUser))
log.Debugf("Merged userinfo to send back; %v", userInfo)
}
}
return userInfo, nil
}
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
@@ -523,6 +504,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
//
//
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
@@ -548,7 +530,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
}
// we should register the account ID to this user's metadata in our IDP manager
err = am.addAccountIDToIDPAppMeta(claims.UserId, existingAcc)
err = am.updateIDPMetadata(claims.UserId, existingAcc.Id)
if err != nil {
return err
}
@@ -586,7 +568,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
}
}
err = am.addAccountIDToIDPAppMeta(claims.UserId, account)
err = am.updateIDPMetadata(claims.UserId, account.Id)
if err != nil {
return nil, err
}
@@ -594,65 +576,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
return account, nil
}
// redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) error {
// only possible with the enabled IdP manager
if am.idpManager == nil {
log.Warnf("invites only work with enabled IdP manager")
return nil
}
user, err := am.lookupUserInCache(userID, account)
if err != nil {
return err
}
if user == nil {
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID)
}
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
log.Infof("redeeming invite for user %s account %s", userID, account.Id)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache.
go func() {
_, err = am.refreshCache(account.Id)
if err != nil {
log.Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
return
}
log.Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
}()
}
return nil
}
// GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) {
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
claims.Domain = am.singleAccountModeDomain
claims.DomainCategory = PrivateCategory
log.Infof("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
account, err := am.getAccountWithAuthorizationClaims(claims)
if err != nil {
return nil, err
}
err = am.redeemInvite(account, claims.UserId)
if err != nil {
return nil, err
}
return account, nil
}
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
// GetAccountWithAuthorizationClaims retrievs an account using JWT Claims.
// if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain
//
@@ -669,12 +593,12 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
if claims.DomainCategory != PrivateCategory {
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" {
accountFromID, err := am.GetAccountById(claims.AccountId)
@@ -714,11 +638,6 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
}
}
func isDomainValid(domain string) bool {
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
return re.Match([]byte(domain))
}
// AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
am.mux.Lock()
@@ -769,33 +688,40 @@ func newAccountWithId(accountId, userId, domain string) *Account {
setupKeys := make(map[string]*SetupKey)
defaultKey := GenerateDefaultSetupKey()
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{})
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration)
setupKeys[defaultKey.Key] = defaultKey
setupKeys[oneOffKey.Key] = oneOffKey
network := NewNetwork()
peers := make(map[string]*Peer)
users := make(map[string]*User)
routes := make(map[string]*route.Route)
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userId] = NewAdminUser(userId)
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
acc := &Account{
Id: accountId,
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
CreatedBy: userId,
Domain: domain,
Routes: routes,
NameServerGroups: nameServersGroups,
Id: accountId,
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
CreatedBy: userId,
Domain: domain,
Routes: routes,
}
addAllGroup(acc)
return acc
}
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {
for _, k := range acc.SetupKeys {
if keyId == k.Id {
return k
}
}
return nil
}
func getAccountSetupKeyByKey(acc *Account, key string) *SetupKey {
for _, k := range acc.SetupKeys {
if key == k.Key {

View File

@@ -127,7 +127,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
}
}
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims
type test struct {
@@ -140,7 +140,6 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
expectedMSG string
expectedUserRole UserRole
expectedDomainCategory string
expectedDomain string
expectedPrimaryDomainStatus bool
expectedCreatedBy string
expectedUsers []string
@@ -169,7 +168,6 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomainCategory: "",
expectedDomain: publicDomain,
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pub-domain-user",
expectedUsers: []string{"pub-domain-user"},
@@ -190,7 +188,6 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomain: unknownDomain,
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "unknown-domain-user",
@@ -208,7 +205,6 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: "pvt-domain-user",
@@ -231,7 +227,6 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleUser,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
@@ -249,7 +244,6 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleAdmin,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
@@ -268,32 +262,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleAdmin,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId},
}
testCase7 := test{
name: "User With Private Category And Empty Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomain: "",
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -310,7 +284,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
account, err := manager.GetAccountFromToken(testCase.inputClaims)
account, err := manager.GetAccountWithAuthorizationClaims(testCase.inputClaims)
require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -320,7 +294,6 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "expected user role should match")
require.EqualValues(t, testCase.expectedDomainCategory, account.DomainCategory, "expected account domain category should match")
require.EqualValues(t, testCase.expectedPrimaryDomainStatus, account.IsDomainPrimaryAccount, "expected account primary status should match")
require.EqualValues(t, testCase.expectedDomain, account.Domain, "expected account domain should match")
})
}
}
@@ -874,7 +847,7 @@ func TestGetUsersFromAccount(t *testing.T) {
account.Users[user.Id] = user
}
userInfos, err := manager.GetUsersFromAccount(accountId, "1")
userInfos, err := manager.GetUsersFromAccount(accountId)
if err != nil {
t.Fatal(err)
}
@@ -962,7 +935,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
if err != nil {
return nil, err
}
return BuildManager(store, NewPeersUpdateManager(), nil, "")
return BuildManager(store, NewPeersUpdateManager(), nil)
}
func createStore(t *testing.T) (Store, error) {

View File

@@ -1,58 +0,0 @@
package server
import (
"fmt"
)
const (
// UserAlreadyExists indicates that user already exists
UserAlreadyExists ErrorType = 1
// AccountNotFound indicates that specified account hasn't been found
AccountNotFound ErrorType = iota
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed ErrorType = iota
// UserNotFound indicates that user wasn't found in the system (or under a given Account)
UserNotFound ErrorType = iota
// PermissionDenied indicates that user has no permissions to view data
PermissionDenied ErrorType = iota
)
// ErrorType is a type of the Error
type ErrorType int32
// Error is an internal error
type Error struct {
errorType ErrorType
message string
}
// Type returns the Type of the error
func (e *Error) Type() ErrorType {
return e.errorType
}
// Error is an error string
func (e *Error) Error() string {
return e.message
}
// Errorf returns Error(errorType, fmt.Sprintf(format, a...)).
func Errorf(errorType ErrorType, format string, a ...interface{}) error {
return &Error{
errorType: errorType,
message: fmt.Sprintf(format, a...),
}
}
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
func FromError(err error) (s *Error, ok bool) {
if err == nil {
return nil, true
}
if e, ok := err.(*Error); ok {
return e, true
}
return nil, false
}

View File

@@ -29,7 +29,6 @@ type FileStore struct {
PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"`
PeerKeyID2RouteIDs map[string]map[string]struct{} `json:"-"`
AccountPrefix2RouteIDs map[string]map[string][]string `json:"-"`
InstallationID string
// mutex to synchronise Store read/write operations
mux sync.Mutex `json:"-"`
@@ -416,10 +415,8 @@ func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts() (all []*Account) {
s.mux.Lock()
defer s.mux.Unlock()
for _, a := range s.Accounts {
all = append(all, a.Copy())
all = append(all, a)
}
return all
@@ -569,18 +566,3 @@ func (s *FileStore) GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]
return routes, nil
}
// GetInstallationID returns the installation ID from the store
func (s *FileStore) GetInstallationID() string {
return s.InstallationID
}
// SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(id string) error {
s.mux.Lock()
defer s.mux.Unlock()
s.InstallationID = id
return s.persist(s.storeFile)
}

View File

@@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
gPeer "google.golang.org/grpc/peer"
"strings"
@@ -18,7 +17,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
gRPCPeer "google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
@@ -31,12 +29,10 @@ type GRPCServer struct {
config *Config
turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware
appMetrics telemetry.AppMetrics
}
// NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager,
turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics) (*GRPCServer, error) {
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@@ -56,16 +52,6 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
log.Debug("unable to use http config to create new jwt middleware")
}
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels))
})
if err != nil {
return nil, err
}
}
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
@@ -74,15 +60,11 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
config: config,
turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware,
appMetrics: appMetrics,
}, nil
}
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
// todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
}
now := time.Now().Add(24 * time.Hour)
secs := int64(now.Second())
nanos := int32(now.Nanosecond())
@@ -97,13 +79,7 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
p, ok := gRPCPeer.FromContext(srv.Context())
if ok {
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
}
log.Debugf("Sync request from peer %s", req.WgPubKey)
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
@@ -201,7 +177,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
}
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
_, err = s.accountManager.GetAccountFromToken(claims)
_, err = s.accountManager.GetAccountWithAuthorizationClaims(claims)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
@@ -279,13 +255,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
p, ok := gRPCPeer.FromContext(ctx)
if ok {
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
}
log.Debugf("Login request from peer %s", req.WgPubKey)
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {

View File

@@ -3,5 +3,3 @@ generate:
models: true
embedded-spec: false
output: types.gen.go
compatibility:
always-prefix-enum-values: true

View File

@@ -11,6 +11,6 @@ fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@4a1477f6a8ba6ca8115cc23bb2fb67f0b9fca18e
go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@v1.11.0
oapi-codegen --config cfg.yaml openapi.yml
cd "$old_pwd"

View File

@@ -16,8 +16,6 @@ tags:
description: Interact with and view information about rules.
- name: Routes
description: Interact with and view information about routes.
- name: DNS
description: Interact with and view information about DNS configuration.
components:
schemas:
User:
@@ -33,59 +31,13 @@ components:
description: User's name from idp provider
type: string
role:
description: User's NetBird account role
description: User's Netbird account role
type: string
status:
description: User's status
type: string
enum: [ "active","invited","disabled" ]
auto_groups:
description: Groups to auto-assign to peers registered by this user
type: array
items:
type: string
required:
- id
- email
- name
- role
- auto_groups
- status
UserRequest:
type: object
properties:
role:
description: User's NetBird account role
type: string
auto_groups:
description: Groups to auto-assign to peers registered by this user
type: array
items:
type: string
required:
- role
- auto_groups
UserCreateRequest:
type: object
properties:
role:
description: User's NetBird account role
type: string
email:
description: User's Email to send invite to
type: string
name:
description: User's full name
type: string
auto_groups:
description: Groups to auto-assign to peers registered by this user
type: array
items:
type: string
required:
- role
- auto_groups
- email
PeerMinimum:
type: object
properties:
@@ -124,18 +76,20 @@ components:
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
activated_by:
description: Provides information of who activated the Peer. User or Setup Key
type: object
properties:
type:
type: string
value:
type: string
required:
- type
- value
ssh_enabled:
description: Indicates whether SSH server is enabled on this peer
type: boolean
user_id:
description: User ID of the user that enrolled this peer
type: string
hostname:
description: Hostname of the machine
type: string
ui_version:
description: Peer's desktop UI version
type: string
required:
- ip
- connected
@@ -143,8 +97,8 @@ components:
- os
- version
- groups
- activated_by
- ssh_enabled
- hostname
SetupKey:
type: object
properties:
@@ -180,15 +134,6 @@ components:
state:
description: Setup key status, "valid", "overused","expired" or "revoked"
type: string
auto_groups:
description: Setup key groups to auto-assign to peers registered with this key
type: array
items:
type: string
updated_at:
description: Setup key last update date
type: string
format: date-time
required:
- id
- key
@@ -200,8 +145,6 @@ components:
- used_times
- last_used
- state
- auto_groups
- updated_at
SetupKeyRequest:
type: object
properties:
@@ -217,17 +160,11 @@ components:
revoked:
description: Setup key revocation status
type: boolean
auto_groups:
description: Setup key groups to auto-assign to peers registered with this key
type: array
items:
type: string
required:
- name
- type
- expires_in
- revoked
- auto_groups
GroupMinimum:
type: object
properties:
@@ -401,88 +338,6 @@ components:
enum: [ "network","network_id","description","enabled","peer","metric","masquerade" ]
required:
- path
Nameserver:
type: object
properties:
ip:
description: Nameserver IP
type: string
ns_type:
description: Nameserver Type
type: string
enum: ["udp"]
port:
description: Nameserver Port
type: integer
required:
- ip
- ns_type
- port
NameserverGroupRequest:
type: object
properties:
name:
description: Nameserver group name
type: string
maxLength: 40
minLength: 1
description:
description: Nameserver group description
type: string
nameservers:
description: Nameserver group
minLength: 1
maxLength: 2
type: array
items:
$ref: '#/components/schemas/Nameserver'
enabled:
description: Nameserver group status
type: boolean
groups:
description: Nameserver group tag groups
type: array
items:
type: string
primary:
description: Nameserver group primary status
type: boolean
domains:
description: Nameserver group domain list
type: array
items:
type: string
minLength: 1
maxLength: 255
required:
- name
- description
- nameservers
- enabled
- groups
- primary
- domains
NameserverGroup:
allOf:
- type: object
properties:
id:
description: Nameserver group ID
type: string
required:
- id
- $ref: '#/components/schemas/NameserverGroupRequest'
NameserverGroupPatchOperation:
allOf:
- $ref: '#/components/schemas/PatchMinimum'
- type: object
properties:
path:
description: Nameserver group field to update in form /<field>
type: string
enum: [ "name","description","enabled","groups","nameservers" ]
required:
- path
responses:
not_found:
@@ -537,67 +392,6 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/:
post:
summary: Create a User (invite)
tags: [ Users]
security:
- BearerAuth: [ ]
requestBody:
description: User invite information
content:
'application/json':
schema:
$ref: '#/components/schemas/UserCreateRequest'
responses:
'200':
description: A User object
content:
application/json:
schema:
$ref: '#/components/schemas/User'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/{id}:
put:
summary: Update information about a User
tags: [ Users]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The User ID
requestBody:
description: User update
content:
'application/json':
schema:
$ref: '#/components/schemas/UserRequest'
responses:
'200':
description: A User object
content:
application/json:
schema:
$ref: '#/components/schemas/User'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers:
get:
summary: Returns a list of all peers
@@ -1375,176 +1169,6 @@ paths:
schema:
type: string
description: The Route ID
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/nameservers:
get:
summary: Returns a list of all Nameserver Groups
tags: [ DNS ]
security:
- BearerAuth: [ ]
responses:
'200':
description: A JSON Array of Nameserver Groups
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Creates a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
requestBody:
description: New Nameserver Groups request
content:
'application/json':
schema:
$ref: '#/components/schemas/NameserverGroupRequest'
responses:
'200':
description: A Nameserver Groups Object
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/nameservers/{id}:
get:
summary: Get information about a Nameserver Groups
tags: [ DNS ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Nameserver Group ID
responses:
'200':
description: A Nameserver Group object
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update/Replace a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Nameserver Group ID
requestBody:
description: Update Nameserver Group request
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroupRequest'
responses:
'200':
description: A Nameserver Group object
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
patch:
summary: Update information about a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Nameserver Group ID
requestBody:
description: Update Nameserver Group request using a list of json patch objects
content:
'application/json':
schema:
type: array
items:
$ref: '#/components/schemas/NameserverGroupPatchOperation'
responses:
'200':
description: A Nameserver Group object
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Nameserver Group ID
responses:
'200':
description: Delete status code

View File

@@ -1,6 +1,6 @@
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/deepmap/oapi-codegen version v1.11.1-0.20220912230023-4a1477f6a8ba DO NOT EDIT.
// Code generated by github.com/deepmap/oapi-codegen version v1.11.0 DO NOT EDIT.
package api
import (
@@ -24,27 +24,6 @@ const (
GroupPatchOperationPathPeers GroupPatchOperationPath = "peers"
)
// Defines values for NameserverNsType.
const (
NameserverNsTypeUdp NameserverNsType = "udp"
)
// Defines values for NameserverGroupPatchOperationOp.
const (
NameserverGroupPatchOperationOpAdd NameserverGroupPatchOperationOp = "add"
NameserverGroupPatchOperationOpRemove NameserverGroupPatchOperationOp = "remove"
NameserverGroupPatchOperationOpReplace NameserverGroupPatchOperationOp = "replace"
)
// Defines values for NameserverGroupPatchOperationPath.
const (
NameserverGroupPatchOperationPathDescription NameserverGroupPatchOperationPath = "description"
NameserverGroupPatchOperationPathEnabled NameserverGroupPatchOperationPath = "enabled"
NameserverGroupPatchOperationPathGroups NameserverGroupPatchOperationPath = "groups"
NameserverGroupPatchOperationPathName NameserverGroupPatchOperationPath = "name"
NameserverGroupPatchOperationPathNameservers NameserverGroupPatchOperationPath = "nameservers"
)
// Defines values for PatchMinimumOp.
const (
PatchMinimumOpAdd PatchMinimumOp = "add"
@@ -87,439 +66,300 @@ const (
RulePatchOperationPathSources RulePatchOperationPath = "sources"
)
// Defines values for UserStatus.
const (
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
UserStatusInvited UserStatus = "invited"
)
// Group defines model for Group.
type Group struct {
// Id Group ID
// Group ID
Id string `json:"id"`
// Name Group Name identifier
// Group Name identifier
Name string `json:"name"`
// Peers List of peers object
// List of peers object
Peers []PeerMinimum `json:"peers"`
// PeersCount Count of peers associated to the group
// Count of peers associated to the group
PeersCount int `json:"peers_count"`
}
// GroupMinimum defines model for GroupMinimum.
type GroupMinimum struct {
// Id Group ID
// Group ID
Id string `json:"id"`
// Name Group Name identifier
// Group Name identifier
Name string `json:"name"`
// PeersCount Count of peers associated to the group
// Count of peers associated to the group
PeersCount int `json:"peers_count"`
}
// GroupPatchOperation defines model for GroupPatchOperation.
type GroupPatchOperation struct {
// Op Patch operation type
// Patch operation type
Op GroupPatchOperationOp `json:"op"`
// Path Group field to update in form /<field>
// Group field to update in form /<field>
Path GroupPatchOperationPath `json:"path"`
// Value Values to be applied
// Values to be applied
Value []string `json:"value"`
}
// GroupPatchOperationOp Patch operation type
// Patch operation type
type GroupPatchOperationOp string
// GroupPatchOperationPath Group field to update in form /<field>
// Group field to update in form /<field>
type GroupPatchOperationPath string
// Nameserver defines model for Nameserver.
type Nameserver struct {
// Ip Nameserver IP
Ip string `json:"ip"`
// NsType Nameserver Type
NsType NameserverNsType `json:"ns_type"`
// Port Nameserver Port
Port int `json:"port"`
}
// NameserverNsType Nameserver Type
type NameserverNsType string
// NameserverGroup defines model for NameserverGroup.
type NameserverGroup struct {
// Description Nameserver group description
Description string `json:"description"`
// Domains Nameserver group domain list
Domains []string `json:"domains"`
// Enabled Nameserver group status
Enabled bool `json:"enabled"`
// Groups Nameserver group tag groups
Groups []string `json:"groups"`
// Id Nameserver group ID
Id string `json:"id"`
// Name Nameserver group name
Name string `json:"name"`
// Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
Primary bool `json:"primary"`
}
// NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation.
type NameserverGroupPatchOperation struct {
// Op Patch operation type
Op NameserverGroupPatchOperationOp `json:"op"`
// Path Nameserver group field to update in form /<field>
Path NameserverGroupPatchOperationPath `json:"path"`
// Value Values to be applied
Value []string `json:"value"`
}
// NameserverGroupPatchOperationOp Patch operation type
type NameserverGroupPatchOperationOp string
// NameserverGroupPatchOperationPath Nameserver group field to update in form /<field>
type NameserverGroupPatchOperationPath string
// NameserverGroupRequest defines model for NameserverGroupRequest.
type NameserverGroupRequest struct {
// Description Nameserver group description
Description string `json:"description"`
// Domains Nameserver group domain list
Domains []string `json:"domains"`
// Enabled Nameserver group status
Enabled bool `json:"enabled"`
// Groups Nameserver group tag groups
Groups []string `json:"groups"`
// Name Nameserver group name
Name string `json:"name"`
// Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
Primary bool `json:"primary"`
}
// PatchMinimum defines model for PatchMinimum.
type PatchMinimum struct {
// Op Patch operation type
// Patch operation type
Op PatchMinimumOp `json:"op"`
// Value Values to be applied
// Values to be applied
Value []string `json:"value"`
}
// PatchMinimumOp Patch operation type
// Patch operation type
type PatchMinimumOp string
// Peer defines model for Peer.
type Peer struct {
// Connected Peer to Management connection status
// Provides information of who activated the Peer. User or Setup Key
ActivatedBy struct {
Type string `json:"type"`
Value string `json:"value"`
} `json:"activated_by"`
// Peer to Management connection status
Connected bool `json:"connected"`
// Groups Groups that the peer belongs to
// Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
// Hostname Hostname of the machine
Hostname string `json:"hostname"`
// Id Peer ID
// Peer ID
Id string `json:"id"`
// Ip Peer's IP address
// Peer's IP address
Ip string `json:"ip"`
// LastSeen Last time peer connected to Netbird's management service
// Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// Name Peer's hostname
// Peer's hostname
Name string `json:"name"`
// Os Peer's operating system and version
// Peer's operating system and version
Os string `json:"os"`
// SshEnabled Indicates whether SSH server is enabled on this peer
// Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"`
// UiVersion Peer's desktop UI version
UiVersion *string `json:"ui_version,omitempty"`
// UserId User ID of the user that enrolled this peer
UserId *string `json:"user_id,omitempty"`
// Version Peer's daemon or cli version
// Peer's daemon or cli version
Version string `json:"version"`
}
// PeerMinimum defines model for PeerMinimum.
type PeerMinimum struct {
// Id Peer ID
// Peer ID
Id string `json:"id"`
// Name Peer's hostname
// Peer's hostname
Name string `json:"name"`
}
// Route defines model for Route.
type Route struct {
// Description Route description
// Route description
Description string `json:"description"`
// Enabled Route status
// Route status
Enabled bool `json:"enabled"`
// Id Route Id
// Route Id
Id string `json:"id"`
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
// Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"`
// Metric Route metric number. Lowest number has higher priority
// Route metric number. Lowest number has higher priority
Metric int `json:"metric"`
// Network Network range in CIDR format
// Network range in CIDR format
Network string `json:"network"`
// NetworkId Route network identifier, to group HA routes
// Route network identifier, to group HA routes
NetworkId string `json:"network_id"`
// NetworkType Network type indicating if it is IPv4 or IPv6
// Network type indicating if it is IPv4 or IPv6
NetworkType string `json:"network_type"`
// Peer Peer Identifier associated with route
// Peer Identifier associated with route
Peer string `json:"peer"`
}
// RoutePatchOperation defines model for RoutePatchOperation.
type RoutePatchOperation struct {
// Op Patch operation type
// Patch operation type
Op RoutePatchOperationOp `json:"op"`
// Path Route field to update in form /<field>
// Route field to update in form /<field>
Path RoutePatchOperationPath `json:"path"`
// Value Values to be applied
// Values to be applied
Value []string `json:"value"`
}
// RoutePatchOperationOp Patch operation type
// Patch operation type
type RoutePatchOperationOp string
// RoutePatchOperationPath Route field to update in form /<field>
// Route field to update in form /<field>
type RoutePatchOperationPath string
// RouteRequest defines model for RouteRequest.
type RouteRequest struct {
// Description Route description
// Route description
Description string `json:"description"`
// Enabled Route status
// Route status
Enabled bool `json:"enabled"`
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
// Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"`
// Metric Route metric number. Lowest number has higher priority
// Route metric number. Lowest number has higher priority
Metric int `json:"metric"`
// Network Network range in CIDR format
// Network range in CIDR format
Network string `json:"network"`
// NetworkId Route network identifier, to group HA routes
// Route network identifier, to group HA routes
NetworkId string `json:"network_id"`
// Peer Peer Identifier associated with route
// Peer Identifier associated with route
Peer string `json:"peer"`
}
// Rule defines model for Rule.
type Rule struct {
// Description Rule friendly description
// Rule friendly description
Description string `json:"description"`
// Destinations Rule destination groups
// Rule destination groups
Destinations []GroupMinimum `json:"destinations"`
// Disabled Rules status
// Rules status
Disabled bool `json:"disabled"`
// Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
// Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"`
// Id Rule ID
// Rule ID
Id string `json:"id"`
// Name Rule name identifier
// Rule name identifier
Name string `json:"name"`
// Sources Rule source groups
// Rule source groups
Sources []GroupMinimum `json:"sources"`
}
// RuleMinimum defines model for RuleMinimum.
type RuleMinimum struct {
// Description Rule friendly description
// Rule friendly description
Description string `json:"description"`
// Disabled Rules status
// Rules status
Disabled bool `json:"disabled"`
// Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
// Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"`
// Name Rule name identifier
// Rule name identifier
Name string `json:"name"`
}
// RulePatchOperation defines model for RulePatchOperation.
type RulePatchOperation struct {
// Op Patch operation type
// Patch operation type
Op RulePatchOperationOp `json:"op"`
// Path Rule field to update in form /<field>
// Rule field to update in form /<field>
Path RulePatchOperationPath `json:"path"`
// Value Values to be applied
// Values to be applied
Value []string `json:"value"`
}
// RulePatchOperationOp Patch operation type
// Patch operation type
type RulePatchOperationOp string
// RulePatchOperationPath Rule field to update in form /<field>
// Rule field to update in form /<field>
type RulePatchOperationPath string
// SetupKey defines model for SetupKey.
type SetupKey struct {
// AutoGroups Setup key groups to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Expires Setup Key expiration date
// Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
// Setup Key ID
Id string `json:"id"`
// Key Setup Key value
// Setup Key value
Key string `json:"key"`
// LastUsed Setup key last usage date
// Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
// Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
// Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
// Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
// Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsedTimes Usage count of setup key
// Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
// Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyRequest defines model for SetupKeyRequest.
type SetupKeyRequest struct {
// AutoGroups Setup key groups to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// ExpiresIn Expiration time in seconds
// Expiration time in seconds
ExpiresIn int `json:"expires_in"`
// Name Setup Key name
// Setup Key name
Name string `json:"name"`
// Revoked Setup key revocation status
// Setup key revocation status
Revoked bool `json:"revoked"`
// Type Setup key type, one-off for single time usage and reusable
// Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
}
// User defines model for User.
type User struct {
// AutoGroups Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// Email User's email address
// User's email address
Email string `json:"email"`
// Id User ID
// User ID
Id string `json:"id"`
// Name User's name from idp provider
// User's name from idp provider
Name string `json:"name"`
// Role User's NetBird account role
Role string `json:"role"`
// Status User's status
Status UserStatus `json:"status"`
}
// UserStatus User's status
type UserStatus string
// UserCreateRequest defines model for UserCreateRequest.
type UserCreateRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// Email User's Email to send invite to
Email string `json:"email"`
// Name User's full name
Name *string `json:"name,omitempty"`
// Role User's NetBird account role
// User's Netbird account role
Role string `json:"role"`
}
// UserRequest defines model for UserRequest.
type UserRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// Role User's NetBird account role
Role string `json:"role"`
}
// PatchApiDnsNameserversIdJSONBody defines parameters for PatchApiDnsNameserversId.
type PatchApiDnsNameserversIdJSONBody = []NameserverGroupPatchOperation
// PostApiGroupsJSONBody defines parameters for PostApiGroups.
type PostApiGroupsJSONBody struct {
Name string `json:"name"`
@@ -541,22 +381,28 @@ type PutApiPeersIdJSONBody struct {
SshEnabled bool `json:"ssh_enabled"`
}
// PostApiRoutesJSONBody defines parameters for PostApiRoutes.
type PostApiRoutesJSONBody = RouteRequest
// PatchApiRoutesIdJSONBody defines parameters for PatchApiRoutesId.
type PatchApiRoutesIdJSONBody = []RoutePatchOperation
// PutApiRoutesIdJSONBody defines parameters for PutApiRoutesId.
type PutApiRoutesIdJSONBody = RouteRequest
// PostApiRulesJSONBody defines parameters for PostApiRules.
type PostApiRulesJSONBody struct {
// Description Rule friendly description
// Rule friendly description
Description string `json:"description"`
Destinations *[]string `json:"destinations,omitempty"`
// Disabled Rules status
// Rules status
Disabled bool `json:"disabled"`
// Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
// Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"`
// Name Rule name identifier
// Rule name identifier
Name string `json:"name"`
Sources *[]string `json:"sources,omitempty"`
}
@@ -566,29 +412,26 @@ type PatchApiRulesIdJSONBody = []RulePatchOperation
// PutApiRulesIdJSONBody defines parameters for PutApiRulesId.
type PutApiRulesIdJSONBody struct {
// Description Rule friendly description
// Rule friendly description
Description string `json:"description"`
Destinations *[]string `json:"destinations,omitempty"`
// Disabled Rules status
// Rules status
Disabled bool `json:"disabled"`
// Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
// Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"`
// Name Rule name identifier
// Rule name identifier
Name string `json:"name"`
Sources *[]string `json:"sources,omitempty"`
}
// PostApiDnsNameserversJSONRequestBody defines body for PostApiDnsNameservers for application/json ContentType.
type PostApiDnsNameserversJSONRequestBody = NameserverGroupRequest
// PostApiSetupKeysJSONBody defines parameters for PostApiSetupKeys.
type PostApiSetupKeysJSONBody = SetupKeyRequest
// PatchApiDnsNameserversIdJSONRequestBody defines body for PatchApiDnsNameserversId for application/json ContentType.
type PatchApiDnsNameserversIdJSONRequestBody = PatchApiDnsNameserversIdJSONBody
// PutApiDnsNameserversIdJSONRequestBody defines body for PutApiDnsNameserversId for application/json ContentType.
type PutApiDnsNameserversIdJSONRequestBody = NameserverGroupRequest
// PutApiSetupKeysIdJSONBody defines parameters for PutApiSetupKeysId.
type PutApiSetupKeysIdJSONBody = SetupKeyRequest
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody
@@ -603,13 +446,13 @@ type PutApiGroupsIdJSONRequestBody PutApiGroupsIdJSONBody
type PutApiPeersIdJSONRequestBody PutApiPeersIdJSONBody
// PostApiRoutesJSONRequestBody defines body for PostApiRoutes for application/json ContentType.
type PostApiRoutesJSONRequestBody = RouteRequest
type PostApiRoutesJSONRequestBody = PostApiRoutesJSONBody
// PatchApiRoutesIdJSONRequestBody defines body for PatchApiRoutesId for application/json ContentType.
type PatchApiRoutesIdJSONRequestBody = PatchApiRoutesIdJSONBody
// PutApiRoutesIdJSONRequestBody defines body for PutApiRoutesId for application/json ContentType.
type PutApiRoutesIdJSONRequestBody = RouteRequest
type PutApiRoutesIdJSONRequestBody = PutApiRoutesIdJSONBody
// PostApiRulesJSONRequestBody defines body for PostApiRules for application/json ContentType.
type PostApiRulesJSONRequestBody PostApiRulesJSONBody
@@ -621,13 +464,7 @@ type PatchApiRulesIdJSONRequestBody = PatchApiRulesIdJSONBody
type PutApiRulesIdJSONRequestBody PutApiRulesIdJSONBody
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
type PostApiSetupKeysJSONRequestBody = PostApiSetupKeysJSONBody
// PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType.
type PutApiSetupKeysIdJSONRequestBody = SetupKeyRequest
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
type PostApiUsersJSONRequestBody = UserCreateRequest
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
type PutApiUsersIdJSONRequestBody = UserRequest
type PutApiSetupKeysIdJSONRequestBody = PutApiSetupKeysIdJSONBody

View File

@@ -33,7 +33,7 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
// GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -50,7 +50,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
// UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -111,7 +111,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
// PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -226,7 +226,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
// CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -260,7 +260,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
// DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -295,7 +295,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
// GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return

View File

@@ -25,7 +25,7 @@ var TestPeers = map[string]*server.Peer{
"B": &server.Peer{Key: "B", IP: net.ParseIP("200.200.200.200")},
}
func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
func initGroupTestData(groups ...*server.Group) *Groups {
return &Groups{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID string, group *server.Group) error {
@@ -67,17 +67,14 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
}
return nil, fmt.Errorf("peer not found")
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: TestPeers,
Users: map[string]*server.User{
user.Id: user,
},
Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
"id-existed": &server.Group{ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": &server.Group{ID: "id-all", Name: "All"}},
}, nil
},
},
@@ -123,8 +120,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser, group)
p := initGroupTestData(group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -274,8 +270,7 @@ func TestWriteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

View File

@@ -4,14 +4,12 @@ import (
"github.com/gorilla/mux"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/rs/cors"
"net/http"
)
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string,
appMetrics telemetry.AppMetrics) (http.Handler, error) {
func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware(
authIssuer,
authAudience,
@@ -23,15 +21,12 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
corsMiddleware := cors.AllowAll()
acMiddleware := middleware.NewAccessControl(
acMiddleware := middleware.NewAccessControll(
authAudience,
accountManager.IsUserAdmin)
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()
apiHandler := rootRouter.PathPrefix("/api").Subrouter()
apiHandler.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
apiHandler := mux.NewRouter()
apiHandler.Use(corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
groupsHandler := NewGroups(accountManager, authAudience)
rulesHandler := NewRules(accountManager, authAudience)
@@ -39,69 +34,39 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
keysHandler := NewSetupKeysHandler(accountManager, authAudience)
userHandler := NewUserHandler(accountManager, authAudience)
routesHandler := NewRoutes(accountManager, authAudience)
nameserversHandler := NewNameservers(accountManager, authAudience)
apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer).
apiHandler.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "PUT", "OPTIONS")
apiHandler.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/rules/{id}", rulesHandler.UpdateRuleHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/rules/{id}", rulesHandler.PatchRuleHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.UpdateRuleHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.PatchRuleHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/groups", groupsHandler.CreateGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/groups/{id}", groupsHandler.UpdateGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/groups/{id}", groupsHandler.PatchGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/groups", groupsHandler.CreateGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.UpdateGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.PatchGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/routes", routesHandler.GetAllRoutesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/routes", routesHandler.CreateRouteHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/routes/{id}", routesHandler.UpdateRouteHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/routes/{id}", routesHandler.PatchRouteHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/routes", routesHandler.GetAllRoutesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/routes", routesHandler.CreateRouteHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.UpdateRouteHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.PatchRouteHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameserversHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.GetNameserverGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroupHandler).Methods("DELETE", "OPTIONS")
err = apiHandler.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil {
return err
}
for _, method := range methods {
template, err := route.GetPathTemplate()
if err != nil {
return err
}
err = metricsMiddleware.AddHTTPRequestResponseCounter(template, method)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return rootRouter, nil
return apiHandler, nil
}

View File

@@ -9,25 +9,24 @@ import (
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
// AccessControll middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControll struct {
jwtExtractor jwtclaims.ClaimsExtractor
isUserAdmin IsUserAdminFunc
audience string
}
// NewAccessControl instance constructor
func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessControl {
return &AccessControl{
// NewAccessControll instance constructor
func NewAccessControll(audience string, isUserAdmin IsUserAdminFunc) *AccessControll {
return &AccessControll{
isUserAdmin: isUserAdmin,
audience: audience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
}
}
// Handler method of the middleware which forbids all modify requests for non admin users
// It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler {
// Handler method of the middleware which forbinneds all modify requests for non admin users
func (a *AccessControll) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience)
@@ -35,6 +34,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
if err != nil {
http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized)
return
}
if !ok {

View File

@@ -186,7 +186,7 @@ func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Reque
validatedToken, err := m.ValidateAndParse(token)
if err != nil {
m.Options.ErrorHandler(w, r, err.Error())
m.Options.ErrorHandler(w, r, "The token isn't valid")
return err
}

View File

@@ -1,286 +0,0 @@
package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"net/http"
)
// Nameservers is the nameserver group handler of the account
type Nameservers struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
authAudience string
}
// NewNameservers returns a new instance of Nameservers handler
func NewNameservers(accountManager server.AccountManager, authAudience string) *Nameservers {
return &Nameservers{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
}
}
// GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id)
if err != nil {
toHTTPError(err, w)
return
}
apiNameservers := make([]*api.NameserverGroup, 0)
for _, r := range nsGroups {
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
}
writeJSONObject(w, apiNameservers)
}
// CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
var req api.PostApiDnsNameserversJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
if err != nil {
toHTTPError(err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp)
}
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
return
}
var req api.PutApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
updatedNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID,
Name: req.Name,
Description: req.Description,
NameServers: nsList,
Groups: req.Groups,
Enabled: req.Enabled,
}
err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup)
if err != nil {
toHTTPError(err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp)
}
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
return
}
var req api.PatchApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var operations []server.NameServerGroupUpdateOperation
for _, patch := range req {
if patch.Op != api.NameserverGroupPatchOperationOpReplace {
http.Error(w, fmt.Sprintf("nameserver groups only accepts replace operations, got %s", patch.Op),
http.StatusBadRequest)
return
}
switch patch.Path {
case api.NameserverGroupPatchOperationPathName:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupName,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathDescription:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupDescription,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathNameservers:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupNameServers,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathGroups:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupGroups,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathEnabled:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupEnabled,
Values: patch.Value,
})
default:
http.Error(w, "invalid patch path", http.StatusBadRequest)
return
}
}
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations)
if err != nil {
toHTTPError(err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp)
}
// DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
return
}
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID)
if err != nil {
toHTTPError(err, w)
return
}
writeJSONObject(w, "")
}
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID)
if err != nil {
toHTTPError(err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp)
}
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {
var nsList []nbdns.NameServer
for _, apiNS := range apiNSList {
parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s:%d", apiNS.NsType, apiNS.Ip, apiNS.Port))
if err != nil {
return nil, err
}
nsList = append(nsList, parsed)
}
return nsList, nil
}
func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.NameserverGroup {
var nsList []api.Nameserver
for _, ns := range serverNSGroup.NameServers {
apiNS := api.Nameserver{
Ip: ns.IP.String(),
NsType: api.NameserverNsType(ns.NSType.String()),
Port: ns.Port,
}
nsList = append(nsList, apiNS)
}
return &api.NameserverGroup{
Id: serverNSGroup.ID,
Name: serverNSGroup.Name,
Description: serverNSGroup.Description,
Groups: serverNSGroup.Groups,
Nameservers: nsList,
Enabled: serverNSGroup.Enabled,
}
}

View File

@@ -1,293 +0,0 @@
package http
import (
"bytes"
"encoding/json"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const (
existingNSGroupID = "existingNSGroupID"
notFoundNSGroupID = "notFoundNSGroupID"
testNSGroupAccountID = "test_id"
)
var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: "super",
Description: "super",
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{"testing"},
Enabled: true,
}
func initNameserversTestData() *Nameservers {
return &Nameservers{
accountManager: &mock_server.MockAccountManager{
GetNameServerGroupFunc: func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if nsGroupID == existingNSGroupID {
return baseExistingNSGroup.Copy(), nil
}
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
},
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: name,
Description: description,
NameServers: nameServerList,
Groups: groups,
Enabled: enabled,
Primary: primary,
Domains: domains,
}, nil
},
DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error {
return nil
},
SaveNameServerGroupFunc: func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
if nsGroupToSave.ID == existingNSGroupID {
return nil
}
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
nsGroupToUpdate := baseExistingNSGroup.Copy()
if nsGroupID != nsGroupToUpdate.ID {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
for _, operation := range operations {
switch operation.Type {
case server.UpdateNameServerGroupName:
nsGroupToUpdate.Name = operation.Values[0]
case server.UpdateNameServerGroupDescription:
nsGroupToUpdate.Description = operation.Values[0]
case server.UpdateNameServerGroupNameServers:
var parsedNSList []nbdns.NameServer
for _, nsURL := range operation.Values {
parsed, err := nbdns.ParseNameServerURL(nsURL)
if err != nil {
return nil, err
}
parsedNSList = append(parsedNSList, parsed)
}
nsGroupToUpdate.NameServers = parsedNSList
}
}
return nsGroupToUpdate, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
return testingNSAccount, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testNSGroupAccountID,
}
},
},
}
}
func TestNameserversHandlers(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedNSGroup *api.NameserverGroup
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "Get Existing Nameserver Group",
requestType: http.MethodGet,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: toNameserverGroupResponse(baseExistingNSGroup),
},
{
name: "Get Not Existing Nameserver Group",
requestType: http.MethodGet,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
expectedStatus: http.StatusNotFound,
},
{
name: "POST OK",
requestType: http.MethodPost,
requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: "name",
Description: "Post",
Nameservers: []api.Nameserver{
{
Ip: "1.1.1.1",
NsType: "udp",
Port: 53,
},
},
Groups: []string{"group"},
Enabled: true,
},
},
{
name: "POST Invalid Nameserver",
requestType: http.MethodPost,
requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "PUT OK",
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: "name",
Description: "Post",
Nameservers: []api.Nameserver{
{
Ip: "1.1.1.1",
NsType: "udp",
Port: 53,
},
},
Groups: []string{"group"},
Enabled: true,
},
},
{
name: "PUT Not Existing Nameserver Group",
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
name: "PUT Invalid Nameserver",
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "PATCH OK",
requestType: http.MethodPatch,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: baseExistingNSGroup.Name,
Description: "NewDesc",
Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers,
Groups: baseExistingNSGroup.Groups,
Enabled: baseExistingNSGroup.Enabled,
},
},
{
name: "PATCH Invalid Nameserver Group OK",
requestType: http.MethodPatch,
requestPath: "/api/dns/nameservers/" + notFoundRouteID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
}
p := initNameserversTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{id}", p.GetNameserverGroupHandler).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroupHandler).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{id}", p.DeleteNameserverGroupHandler).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{id}", p.UpdateNameserverGroupHandler).Methods("PUT")
router.HandleFunc("/api/dns/nameservers/{id}", p.PatchNameserverGroupHandler).Methods("PATCH")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
got := &api.NameserverGroup{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, tc.expectedNSGroup, got)
})
}
}

View File

@@ -11,7 +11,7 @@ import (
"net/http"
)
// Peers is a handler that returns peers of the account
//Peers is a handler that returns peers of the account
type Peers struct {
accountManager server.AccountManager
authAudience string
@@ -56,7 +56,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
}
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -95,20 +95,15 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
return
}
respBody := []*api.Peer{}
for _, peer := range peers {
for _, peer := range account.Peers {
respBody = append(respBody, toPeerResponse(peer, account))
}
writeJSONObject(w, respBody)
@@ -149,8 +144,5 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: &peer.UserID,
UiVersion: &peer.Meta.UIVersion,
}
}

View File

@@ -16,21 +16,15 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initTestMetaData(peers ...*server.Peer) *Peers {
func initTestMetaData(peer ...*server.Peer) *Peers {
return &Peers{
accountManager: &mock_server.MockAccountManager{
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: map[string]*server.Peer{
"test_peer": peers[0],
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_peer": peer[0],
},
}, nil
},

View File

@@ -33,24 +33,16 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
// GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
routes, err := h.accountManager.ListRoutes(account.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
@@ -64,7 +56,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
// CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -112,7 +104,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -125,13 +117,13 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID, "")
_, err = h.accountManager.GetRoute(account.Id, routeID)
if err != nil {
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
return
}
var req api.PutApiRoutesIdJSONRequestBody
var req api.PutApiRoutesIdJSONBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@@ -185,7 +177,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
// PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -198,7 +190,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID, "")
_, err = h.accountManager.GetRoute(account.Id, routeID)
if err != nil {
log.Error(err)
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
@@ -342,7 +334,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -356,11 +348,6 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
err = h.accountManager.DeleteRoute(account.Id, routeID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -371,7 +358,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
// GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -383,7 +370,7 @@ func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID)
if err != nil {
http.Error(w, "route not found", http.StatusNotFound)
return

View File

@@ -51,15 +51,12 @@ var testingAccount = &server.Account{
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
func initRoutesTestData() *Routes {
return &Routes{
accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) {
GetRouteFunc: func(_, routeID string) (*route.Route, error) {
if routeID == existingRouteID {
return baseExistingRoute, nil
}
@@ -81,10 +78,7 @@ func initRoutesTestData() *Routes {
SaveRouteFunc: func(_ string, _ *route.Route) error {
return nil
},
DeleteRouteFunc: func(_ string, peerIP string) error {
if peerIP != existingRouteID {
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
}
DeleteRouteFunc: func(_ string, _ string) error {
return nil
},
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
@@ -123,7 +117,7 @@ func initRoutesTestData() *Routes {
}
return routeToUpdate, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountWithAuthorizationClaimsFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
return testingAccount, nil
},
},
@@ -161,7 +155,7 @@ func TestRoutesHandlers(t *testing.T) {
{
name: "Get Not Existing Route",
requestType: http.MethodGet,
requestPath: "/api/routes/" + notFoundRouteID,
requestPath: "/api/rules/" + notFoundRouteID,
expectedStatus: http.StatusNotFound,
},
{
@@ -174,7 +168,7 @@ func TestRoutesHandlers(t *testing.T) {
{
name: "Delete Not Existing Route",
requestType: http.MethodDelete,
requestPath: "/api/routes/" + notFoundRouteID,
requestPath: "/api/rules/" + notFoundRouteID,
expectedStatus: http.StatusNotFound,
},
{

View File

@@ -31,29 +31,15 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
// GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
rules := []*api.Rule{}
for _, r := range accountRules {
for _, r := range account.Rules {
rules = append(rules, toRuleResponse(account, r))
}
@@ -62,7 +48,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -132,7 +118,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
// PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -289,7 +275,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
// CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -346,7 +332,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -370,7 +356,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
// GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -384,7 +370,7 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
return
}
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
rule, err := h.accountManager.GetRule(account.Id, ruleID)
if err != nil {
http.Error(w, "rule not found", http.StatusNotFound)
return

View File

@@ -28,7 +28,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
}
return nil
},
GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) {
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) {
if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found")
}
@@ -66,17 +66,14 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
}
return &rule, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}},
Groups: map[string]*server.Group{
"F": {ID: "F"},
"G": {ID: "G"},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"F": &server.Group{ID: "F"},
"G": &server.Group{ID: "G"},
},
}, nil
},

View File

@@ -2,7 +2,6 @@ package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -29,17 +28,54 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
}
}
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
req := &api.PutApiSetupKeysIdJSONRequestBody{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var key *server.SetupKey
if req.Revoked {
//handle only if being revoked, don't allow to enable key again for now
key, err = h.accountManager.RevokeSetupKey(accountId, keyId)
if err != nil {
http.Error(w, "failed revoking key", http.StatusInternalServerError)
return
}
}
if len(req.Name) != 0 {
key, err = h.accountManager.RenameSetupKey(accountId, keyId, req.Name)
if err != nil {
http.Error(w, "failed renaming key", http.StatusInternalServerError)
return
}
}
if key != nil {
writeSuccess(w, key)
}
}
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
account, err := h.accountManager.GetAccountById(accountId)
if err != nil {
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
return
}
for _, key := range account.SetupKeys {
if key.Id == keyId {
writeSuccess(w, key)
return
}
}
http.Error(w, "setup key not found", http.StatusNotFound)
}
func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.Request) {
req := &api.PostApiSetupKeysJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@@ -59,13 +95,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
expiresIn := time.Duration(req.ExpiresIn) * time.Second
if req.AutoGroups == nil {
req.AutoGroups = []string{}
}
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups)
setupKey, err := h.accountManager.AddSetupKey(accountId, req.Name, server.SetupKeyType(req.Type), expiresIn)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
@@ -79,9 +109,8 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
writeSuccess(w, setupKey)
}
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -89,104 +118,55 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
keyID := vars["id"]
if len(keyID) == 0 {
keyId := vars["id"]
if len(keyId) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest)
return
}
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
switch r.Method {
case http.MethodPut:
h.updateKey(account.Id, keyId, w, r)
return
case http.MethodGet:
h.getKey(account.Id, keyId, w, r)
return
default:
http.Error(w, "", http.StatusNotFound)
}
}
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
switch r.Method {
case http.MethodPost:
h.createKey(account.Id, w, r)
return
case http.MethodGet:
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
respBody := []*api.SetupKey{}
for _, key := range account.SetupKeys {
respBody = append(respBody, toResponseBody(key))
}
err = json.NewEncoder(w).Encode(respBody)
if err != nil {
log.Errorf("failed encoding account peers %s: %v", account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
default:
http.Error(w, "", http.StatusNotFound)
}
writeSuccess(w, key)
}
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
vars := mux.Vars(r)
keyID := vars["id"]
if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest)
return
}
req := &api.PutApiSetupKeysIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name == "" {
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest)
return
}
if req.AutoGroups == nil {
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest)
return
}
newKey := &server.SetupKey{}
newKey.AutoGroups = req.AutoGroups
newKey.Revoked = req.Revoked
newKey.Name = req.Name
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
if err != nil {
if e, ok := status.FromError(err); ok {
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
default:
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
}
}
return
}
writeSuccess(w, newKey)
}
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys {
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
}
writeJSONObject(w, apiSetupKeys)
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
@@ -210,19 +190,16 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey {
} else {
state = "valid"
}
return &api.SetupKey{
Id: key.Id,
Key: key.Key,
Name: key.Name,
Expires: key.ExpiresAt,
Type: string(key.Type),
Valid: key.IsValid(),
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
State: state,
AutoGroups: key.AutoGroups,
UpdatedAt: key.UpdatedAt,
Id: key.Id,
Key: key.Key,
Name: key.Name,
Expires: key.ExpiresAt,
Type: string(key.Type),
Valid: key.IsValid(),
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
State: state,
}
}

View File

@@ -1,228 +0,0 @@
package http
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const (
existingSetupKeyID = "existingSetupKeyID"
newSetupKeyName = "New Setup Key"
updatedSetupKeyName = "KKKey"
notFoundSetupKeyID = "notFoundSetupKeyID"
)
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User) *SetupKeys {
return &SetupKeys{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
Groups: map[string]*server.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
}, nil
},
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type {
return newKey, nil
}
return nil, fmt.Errorf("failed creating setup key")
},
GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
switch keyID {
case defaultKey.Id:
return defaultKey, nil
case newKey.Id:
return newKey, nil
default:
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID)
}
},
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil
}
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: user.Id,
Domain: "hotmail.com",
AccountId: testAccountID,
}
},
},
}
}
func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
updatedDefaultSetupKey.Name = updatedSetupKeyName
updatedDefaultSetupKey.Revoked = true
tt := []struct {
name string
requestType string
requestPath string
requestBody io.Reader
expectedStatus int
expectedBody bool
expectedSetupKey *api.SetupKey
expectedSetupKeys []*api.SetupKey
}{
{
name: "Get Setup Keys",
requestType: http.MethodGet,
requestPath: "/api/setup-keys",
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)},
},
{
name: "Get Existing Setup Key",
requestType: http.MethodGet,
requestPath: "/api/setup-keys/" + existingSetupKeyID,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(defaultSetupKey),
},
{
name: "Get Not Existing Setup Key",
requestType: http.MethodGet,
requestPath: "/api/setup-keys/" + notFoundSetupKeyID,
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
name: "Create Setup Key",
requestType: http.MethodPost,
requestPath: "/api/setup-keys",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\"}", newSetupKey.Name, newSetupKey.Type))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(newSetupKey),
},
{
name: "Update Setup Key",
requestType: http.MethodPut,
requestPath: "/api/setup-keys/" + defaultSetupKey.Id,
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"auto_groups\":[\"%s\"], \"revoked\":%v}",
updatedDefaultSetupKey.Type,
updatedDefaultSetupKey.AutoGroups[0],
updatedDefaultSetupKey.Revoked,
))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(updatedDefaultSetupKey),
},
}
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{id}", handler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{id}", handler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
if tc.expectedSetupKey != nil {
got := &api.SetupKey{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assertKeys(t, got, tc.expectedSetupKey)
return
}
if len(tc.expectedSetupKeys) > 0 {
var got []*api.SetupKey
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assertKeys(t, got[0], tc.expectedSetupKeys[0])
return
}
})
}
}
func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) {
// this comparison is done manually because when converting to JSON dates formatted differently
// assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work
assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "")
assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "")
assert.Equal(t, got.Name, expected.Name)
assert.Equal(t, got.Id, expected.Id)
assert.Equal(t, got.Key, expected.Key)
assert.Equal(t, got.Type, expected.Type)
assert.Equal(t, got.UsedTimes, expected.UsedTimes)
assert.Equal(t, got.Revoked, expected.Revoked)
assert.ElementsMatch(t, got.AutoGroups, expected.AutoGroups)
}

View File

@@ -1,15 +1,11 @@
package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
@@ -28,103 +24,6 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) *
}
}
// UpdateUser is a PUT requests to update User data
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "", http.StatusBadRequest)
}
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
vars := mux.Vars(r)
userID := vars["id"]
if len(userID) == 0 {
http.Error(w, "invalid user ID", http.StatusBadRequest)
return
}
req := &api.PutApiUsersIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown {
http.Error(w, "invalid user role", http.StatusBadRequest)
return
}
newUser, err := h.accountManager.SaveUser(account.Id, &server.User{
Id: userID,
Role: userRole,
AutoGroups: req.AutoGroups,
})
if err != nil {
if e, ok := status.FromError(err); ok {
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound)
default:
http.Error(w, "failed to update user", http.StatusInternalServerError)
}
}
return
}
writeJSONObject(w, toUserResponse(newUser))
}
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "", http.StatusNotFound)
}
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
}
req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest)
return
}
newUser, err := h.accountManager.CreateUser(account.Id, &server.UserInfo{
Email: req.Email,
Name: *req.Name,
Role: req.Role,
AutoGroups: req.AutoGroups,
})
if err != nil {
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.UserAlreadyExists:
http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed)
return
default:
}
}
http.Error(w, "failed to invite", http.StatusInternalServerError)
return
}
writeJSONObject(w, toUserResponse(newUser))
}
// GetUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager.
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
@@ -132,20 +31,19 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
log.Error(err)
}
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
data, err := h.accountManager.GetUsersFromAccount(account.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
users := make([]*api.User, 0)
users := []*api.User{}
for _, r := range data {
users = append(users, toUserResponse(r))
}
@@ -154,28 +52,10 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
}
func toUserResponse(user *server.UserInfo) *api.User {
autoGroups := user.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
var userStatus api.UserStatus
switch user.Status {
case "active":
userStatus = api.UserStatusActive
case "invited":
userStatus = api.UserStatusInvited
default:
userStatus = api.UserStatusDisabled
}
return &api.User{
Id: user.ID,
Name: user.Name,
Email: user.Email,
Role: user.Role,
AutoGroups: autoGroups,
Status: userStatus,
Id: user.ID,
Name: user.Name,
Email: user.Email,
Role: user.Role,
}
}

View File

@@ -16,7 +16,7 @@ import (
func initUsers(user ...*server.User) *UserHandler {
return &UserHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
users := make(map[string]*server.User, 0)
for _, u := range user {
users[u.Id] = u
@@ -27,7 +27,7 @@ func initUsers(user ...*server.User) *UserHandler {
Users: users,
}, nil
},
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
GetUsersFromAccountFunc: func(accountID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
for _, v := range user {
users = append(users, &server.UserInfo{
@@ -44,7 +44,7 @@ func initUsers(user ...*server.User) *UserHandler {
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "1",
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}

View File

@@ -6,14 +6,11 @@ import (
"fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"time"
)
// writeJSONObject simply writes object to the HTTP reponse in JSON format
//writeJSONObject simply writes object to the HTTP reponse in JSON format
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
@@ -24,7 +21,7 @@ func writeJSONObject(w http.ResponseWriter, obj interface{}) {
}
}
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
//Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}
@@ -56,42 +53,14 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
func getJWTAccount(accountManager server.AccountManager,
jwtExtractor jwtclaims.ClaimsExtractor,
authAudience string, r *http.Request) (*server.Account, *server.User, error) {
authAudience string, r *http.Request) (*server.Account, error) {
claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
account, err := accountManager.GetAccountFromToken(claims)
account, err := accountManager.GetAccountWithAuthorizationClaims(jwtClaims)
if err != nil {
return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err)
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
}
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, fmt.Errorf("user %s not found", claims.UserId)
}
return account, user, nil
}
func toHTTPError(err error, w http.ResponseWriter) {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", errStatus.String())
log.Error(unhandledMSG)
http.Error(w, unhandledMSG, http.StatusInternalServerError)
return account, nil
}

View File

@@ -6,8 +6,8 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
@@ -25,7 +25,6 @@ type Auth0Manager struct {
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// Auth0ClientConfig auth0 manager client configurations
@@ -53,17 +52,6 @@ type Auth0Credentials struct {
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// createUserRequest is a user create request
type createUserRequest struct {
Email string `json:"email"`
Name string `json:"name"`
AppMeta AppMetadata `json:"app_metadata"`
Connection string `json:"connection"`
Password string `json:"password"`
VerifyEmail bool `json:"verify_email"`
}
// userExportJobRequest is a user export request struct
@@ -99,17 +87,16 @@ type userExportJobStatusResponse struct {
// auth0Profile represents an Auth0 user profile response
type auth0Profile struct {
AccountID string `json:"wt_account_id"`
PendingInvite bool `json:"wt_pending_invite"`
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
CreatedAt string `json:"created_at"`
LastLogin string `json:"last_login"`
AccountID string `json:"wt_account_id"`
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
CreatedAt string `json:"created_at"`
LastLogin string `json:"last_login"`
}
// NewAuth0Manager creates a new instance of the Auth0Manager
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
func NewAuth0Manager(config Auth0ClientConfig) (*Auth0Manager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -137,15 +124,12 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &Auth0Manager{
authIssuer: config.AuthIssuer,
credentials: credentials,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}, nil
}
@@ -176,9 +160,6 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
res, err = c.httpClient.Do(req)
if err != nil {
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountRequestError()
}
return res, err
}
@@ -191,7 +172,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
body, err := ioutil.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
@@ -223,10 +204,6 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
c.mux.Lock()
defer c.mux.Unlock()
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountAuthenticate()
}
// If jwtToken has an expires time and we have enough time to do a request return immediately
if c.jwtStillValid() {
return c.jwtToken, nil
@@ -253,7 +230,7 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
return c.jwtToken, nil
}
func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (string, url.Values, error) {
func batchRequestUsersURL(authIssuer, accountID string, page int) (string, url.Values, error) {
u, err := url.Parse(authIssuer + "/api/v2/users")
if err != nil {
return "", nil, err
@@ -261,7 +238,6 @@ func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (
q := u.Query()
q.Set("page", strconv.Itoa(page))
q.Set("search_engine", "v3")
q.Set("per_page", strconv.Itoa(perPage))
q.Set("q", "app_metadata.wt_account_id:"+accountID)
u.RawQuery = q.Encode()
@@ -283,9 +259,8 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
// auth0 limitation of 1000 users via this endpoint
resultsPerPage := 50
for page := 0; page < 20; page++ {
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page, resultsPerPage)
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page)
if err != nil {
return nil, err
}
@@ -300,46 +275,38 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
res, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
if res.StatusCode != 200 {
return nil, fmt.Errorf("failed requesting user data from IdP %s", string(body))
}
var batch []UserData
err = json.Unmarshal(body, &batch)
if err != nil {
return nil, err
}
log.Debugf("returned user batch for accountID %s on page %d, %v", accountID, page, batch)
log.Debugf("requested batch; %v", batch)
err = res.Body.Close()
if err != nil {
return nil, err
}
for user := range batch {
list = append(list, &batch[user])
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to request UserData from auth0, statusCode %d", res.StatusCode)
}
if len(batch) == 0 || len(batch) < resultsPerPage {
log.Debugf("finished loading users for accountID %s", accountID)
if len(batch) == 0 {
return list, nil
}
for user := range batch {
list = append(list, &batch[user])
}
}
return list, nil
@@ -362,16 +329,9 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
res, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
@@ -407,12 +367,14 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
reqURL := am.authIssuer + "/api/v2/users/" + userID
data, err := am.helper.Marshal(map[string]any{"app_metadata": appMetadata})
data, err := am.helper.Marshal(appMetadata)
if err != nil {
return err
}
payload := strings.NewReader(string(data))
payloadString := fmt.Sprintf("{\"app_metadata\": %s}", string(data))
payload := strings.NewReader(payloadString)
req, err := http.NewRequest("PATCH", reqURL, payload)
if err != nil {
@@ -421,20 +383,13 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating IdP metadata for user %s", userID)
log.Debugf("updating metadata for user %s", userID)
res, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
defer func() {
err = res.Body.Close()
if err != nil {
@@ -449,28 +404,6 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
return nil
}
func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) {
invite := true
req := &createUserRequest{
Email: email,
Name: name,
AppMeta: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
},
Connection: "Username-Password-Authentication",
Password: GeneratePassword(8, 1, 1, 1),
VerifyEmail: true,
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}
func buildUserExportRequest() (string, error) {
req := &userExportJobRequest{}
fields := make([]map[string]string, 0)
@@ -484,11 +417,6 @@ func buildUserExportRequest() (string, error) {
"export_as": "wt_account_id",
})
fields = append(fields, map[string]string{
"name": "app_metadata.wt_pending_invite",
"export_as": "wt_pending_invite",
})
req.Format = "json"
req.Fields = fields
@@ -500,46 +428,32 @@ func buildUserExportRequest() (string, error) {
return string(str), nil
}
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := am.authIssuer + endpoint
reqURL := am.authIssuer + "/api/v2/jobs/users-exports"
payload := strings.NewReader(payloadStr)
req, err := http.NewRequest("POST", reqURL, payload)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
return req, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
payloadString, err := buildUserExportRequest()
if err != nil {
return nil, err
}
payload := strings.NewReader(payloadString)
exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString)
exportJobReq, err := http.NewRequest("POST", reqURL, payload)
if err != nil {
return nil, err
}
exportJobReq.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
exportJobReq.Header.Add("content-type", "application/json")
jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
@@ -550,15 +464,12 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
}
}()
if jobResp.StatusCode != 200 {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
}
var exportJobResp userExportJobResponse
body, err := io.ReadAll(jobResp.Body)
body, err := ioutil.ReadAll(jobResp.Body)
if err != nil {
log.Debugf("Coudln't read export job response; %v", err)
return nil, err
@@ -571,9 +482,6 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
}
if exportJobResp.ID == "" {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
}
@@ -592,96 +500,6 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
return nil, fmt.Errorf("failed extracting user profiles from auth0")
}
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
userResp := []*UserData{}
err = am.helper.Unmarshal(body, &userResp)
if err != nil {
log.Debugf("Coudln't unmarshal export job response; %v", err)
return nil, err
}
return userResp, nil
}
// CreateUser creates a new user in Auth0 Idp and sends an invite
func (am *Auth0Manager) CreateUser(email string, name string, accountID string) (*UserData, error) {
payloadString, err := buildCreateUserRequestPayload(email, name, accountID)
if err != nil {
return nil, err
}
req, err := am.createPostRequest("/api/v2/users", payloadString)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer func() {
err = resp.Body.Close()
if err != nil {
log.Errorf("error while closing create user response body: %v", err)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
var createResp UserData
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Debugf("Coudln't read export job response; %v", err)
return nil, err
}
err = am.helper.Unmarshal(body, &createResp)
if err != nil {
log.Debugf("Coudln't unmarshal export job response; %v", err)
return nil, err
}
if createResp.ID == "" {
return nil, fmt.Errorf("couldn't create user: response %v", resp)
}
log.Debugf("created user %s in account %s", createResp.ID, accountID)
return &createResp, nil
}
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
@@ -754,10 +572,6 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
ID: profile.UserID,
Name: profile.Name,
Email: profile.Email,
AppMetadata: AppMetadata{
WTAccountID: profile.AccountID,
WTPendingInvite: &profile.PendingInvite,
},
})
}
}
@@ -787,12 +601,13 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
log.Errorf("error while closing body for url %s: %v", url, err)
}
}()
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, err
}
return body, nil
}

View File

@@ -3,9 +3,8 @@ package idp
import (
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/require"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
@@ -23,13 +22,13 @@ type mockHTTPClient struct {
}
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
body, err := ioutil.ReadAll(req.Body)
if err == nil {
c.reqBody = string(body)
}
return &http.Response{
StatusCode: c.code,
Body: io.NopCloser(strings.NewReader(c.resBody)),
Body: ioutil.NopCloser(strings.NewReader(c.resBody)),
}, c.err
}
@@ -131,7 +130,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) {
t.Fatal(err)
}
}
body, err := io.ReadAll(res.Body)
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "unable to read the response body")
jwtToken := JWTToken{}
@@ -179,7 +178,7 @@ func TestAuth0_ParseRequestJWTResponse(t *testing.T) {
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
rawBody := io.NopCloser(strings.NewReader(testCase.inputResBody))
rawBody := ioutil.NopCloser(strings.NewReader(testCase.inputResBody))
config := Auth0ClientConfig{}
@@ -321,7 +320,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
exp := 15
token := newTestJWT(t, exp)
appMetadata := AppMetadata{WTAccountID: "ok"}
appMetadata := AppMetadata{WTAccountId: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
@@ -341,7 +340,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID),
expectedReqBody: fmt.Sprintf("{\"app_metadata\": {\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountId),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
@@ -364,7 +363,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":null}}", appMetadata.WTAccountID),
expectedReqBody: fmt.Sprintf("{\"app_metadata\": {\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountId),
appMetadata: appMetadata,
statusCode: 200,
helper: JsonParser{},
@@ -372,23 +371,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2, updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
@@ -476,7 +459,7 @@ func TestNewAuth0Manager(t *testing.T) {
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
_, err := NewAuth0Manager(testCase.inputConfig)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}

View File

@@ -2,7 +2,6 @@ package idp
import (
"fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
"net/http"
"strings"
"time"
@@ -14,8 +13,6 @@ type Manager interface {
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error)
GetAccount(accountId string) ([]*UserData, error)
GetAllAccounts() (map[string][]*UserData, error)
CreateUser(email string, name string, accountID string) (*UserData, error)
GetUserByEmail(email string) ([]*UserData, error)
}
// Config an idp configuration struct to be loaded from management server's config file
@@ -41,18 +38,16 @@ type ManagerHelper interface {
}
type UserData struct {
Email string `json:"email"`
Name string `json:"name"`
ID string `json:"user_id"`
AppMetadata AppMetadata `json:"app_metadata"`
Email string `json:"email"`
Name string `json:"name"`
ID string `json:"user_id"`
}
// AppMetadata user app metadata to associate with a profile
type AppMetadata struct {
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
// Wiretrustee account id to update in the IDP
// maps to wt_account_id when json.marshal
WTAccountID string `json:"wt_account_id,omitempty"`
WTPendingInvite *bool `json:"wt_pending_invite"`
WTAccountId string `json:"wt_account_id"`
}
// JWTToken a JWT object that holds information of a token
@@ -65,12 +60,12 @@ type JWTToken struct {
}
// NewManager returns a new idp manager based on the configuration that it receives
func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
func NewManager(config Config) (Manager, error) {
switch strings.ToLower(config.ManagerType) {
case "none", "":
return nil, nil
case "auth0":
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
return NewAuth0Manager(config.Auth0ClientCredentials)
default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
}

View File

@@ -1,18 +1,6 @@
package idp
import (
"encoding/json"
"math/rand"
"strings"
)
var (
lowerCharSet = "abcdedfghijklmnopqrst"
upperCharSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
specialCharSet = "!@#$%&*"
numberSet = "0123456789"
allCharSet = lowerCharSet + upperCharSet + specialCharSet + numberSet
)
import "encoding/json"
type JsonParser struct{}
@@ -23,37 +11,3 @@ func (JsonParser) Marshal(v interface{}) ([]byte, error) {
func (JsonParser) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// GeneratePassword generates user password
func GeneratePassword(passwordLength, minSpecialChar, minNum, minUpperCase int) string {
var password strings.Builder
//Set special character
for i := 0; i < minSpecialChar; i++ {
random := rand.Intn(len(specialCharSet))
password.WriteString(string(specialCharSet[random]))
}
//Set numeric
for i := 0; i < minNum; i++ {
random := rand.Intn(len(numberSet))
password.WriteString(string(numberSet[random]))
}
//Set uppercase
for i := 0; i < minUpperCase; i++ {
random := rand.Intn(len(upperCharSet))
password.WriteString(string(upperCharSet[random]))
}
remainingLength := passwordLength - minSpecialChar - minNum - minUpperCase
for i := 0; i < remainingLength; i++ {
random := rand.Intn(len(allCharSet))
password.WriteString(string(allCharSet[random]))
}
inRune := []rune(password.String())
rand.Shuffle(len(inRune), func(i, j int) {
inRune[i], inRune[j] = inRune[j], inRune[i]
})
return string(inRune)
}

View File

@@ -403,12 +403,12 @@ func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, erro
return nil, err
}
peersUpdateManager := NewPeersUpdateManager()
accountManager, err := BuildManager(store, peersUpdateManager, nil, "")
accountManager, err := BuildManager(store, peersUpdateManager, nil)
if err != nil {
return nil, err
}
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager)
if err != nil {
return nil, err
}

View File

@@ -2,6 +2,7 @@ package server_test
import (
"context"
"io/ioutil"
"math/rand"
"net"
"os"
@@ -44,7 +45,7 @@ var _ = Describe("Management service", func() {
level, _ := log.ParseLevel("Debug")
log.SetLevel(level)
var err error
dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*")
dataDir, err = ioutil.TempDir("", "wiretrustee_mgmt_test_tmp_*")
Expect(err).NotTo(HaveOccurred())
err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json"))
@@ -493,12 +494,12 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager()
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "")
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {

View File

@@ -1,286 +0,0 @@
// Package metrics gather anonymous information about the usage of NetBird management
package metrics
import (
"context"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/server"
log "github.com/sirupsen/logrus"
"io"
"net/http"
"strings"
"time"
)
const (
// PayloadEvent identifies an event type
PayloadEvent = "self-hosted stats"
// payloadEndpoint metrics defaultEndpoint to send anonymous data
payloadEndpoint = "https://metrics.netbird.io"
// defaultPushInterval default interval to push metrics
defaultPushInterval = 24 * time.Hour
// requestTimeout http request timeout
requestTimeout = 30 * time.Second
)
type getTokenResponse struct {
PublicAPIToken string `json:"public_api_token"`
}
type pushPayload struct {
APIKey string `json:"api_key"`
DistinctID string `json:"distinct_id"`
Event string `json:"event"`
Properties properties `json:"properties"`
Timestamp time.Time `json:"timestamp"`
}
// properties metrics to push
type properties map[string]interface{}
// DataSource metric data source
type DataSource interface {
GetAllAccounts() []*server.Account
}
// ConnManager peer connection manager that holds state for current active connections
type ConnManager interface {
GetAllConnectedPeers() map[string]struct{}
}
// Worker metrics collector and pusher
type Worker struct {
ctx context.Context
id string
dataSource DataSource
connManager ConnManager
startupTime time.Time
lastRun time.Time
}
// NewWorker returns a metrics worker
func NewWorker(ctx context.Context, id string, dataSource DataSource, connManager ConnManager) *Worker {
currentTime := time.Now()
return &Worker{
ctx: ctx,
id: id,
dataSource: dataSource,
connManager: connManager,
startupTime: currentTime,
lastRun: currentTime,
}
}
// Run runs the metrics worker
func (w *Worker) Run() {
pushTicker := time.NewTicker(defaultPushInterval)
for {
select {
case <-w.ctx.Done():
return
case <-pushTicker.C:
err := w.sendMetrics()
if err != nil {
log.Error(err)
}
}
}
}
func (w *Worker) sendMetrics() error {
ctx, cancel := context.WithTimeout(w.ctx, requestTimeout)
defer cancel()
apiKey, err := getAPIKey(ctx)
if err != nil {
return err
}
payload := w.generatePayload(apiKey)
payloadString, err := buildMetricsPayload(payload)
if err != nil {
return err
}
httpClient := http.Client{}
exportJobReq, err := createPostRequest(ctx, payloadEndpoint+"/capture/", payloadString)
if err != nil {
return fmt.Errorf("unable to create metrics post request %v", err)
}
jobResp, err := httpClient.Do(exportJobReq)
if err != nil {
return fmt.Errorf("unable to push metrics %v", err)
}
defer func() {
err = jobResp.Body.Close()
if err != nil {
log.Errorf("error while closing update metrics response body: %v", err)
}
}()
if jobResp.StatusCode != 200 {
return fmt.Errorf("unable to push anonymous metrics, got statusCode %d", jobResp.StatusCode)
}
log.Infof("sent anonymous metrics, next push will happen in %s. "+
"You can disable these metrics by running with flag --disable-anonymous-metrics,"+
" see more information at https://netbird.io/docs/FAQ/metrics-collection", defaultPushInterval)
return nil
}
func (w *Worker) generatePayload(apiKey string) pushPayload {
properties := w.generateProperties()
return pushPayload{
APIKey: apiKey,
DistinctID: w.id,
Event: PayloadEvent,
Properties: properties,
Timestamp: time.Now(),
}
}
func (w *Worker) generateProperties() properties {
var (
uptime float64
accounts int
users int
peers int
setupKeysUsage int
activePeersLastDay int
osPeers map[string]int
userPeers int
rules int
groups int
routes int
nameservers int
version string
)
start := time.Now()
metricsProperties := make(properties)
osPeers = make(map[string]int)
uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers()
version = system.NetbirdVersion()
for _, account := range w.dataSource.GetAllAccounts() {
accounts++
users = users + len(account.Users)
rules = rules + len(account.Rules)
groups = groups + len(account.Groups)
routes = routes + len(account.Routes)
nameservers = nameservers + len(account.NameServerGroups)
for _, key := range account.SetupKeys {
setupKeysUsage = setupKeysUsage + key.UsedTimes
}
for _, peer := range account.Peers {
peers++
if peer.SetupKey != "" {
userPeers++
}
osKey := strings.ToLower(fmt.Sprintf("peer_os_%s", peer.Meta.GoOS))
osCount := osPeers[osKey]
osPeers[osKey] = osCount + 1
_, connected := connections[peer.Key]
if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++
osActiveKey := osKey + "_active"
osActiveCount := osPeers[osActiveKey]
osPeers[osActiveKey] = osActiveCount + 1
}
}
}
metricsProperties["uptime"] = uptime
metricsProperties["accounts"] = accounts
metricsProperties["users"] = users
metricsProperties["peers"] = peers
metricsProperties["setup_keys_usage"] = setupKeysUsage
metricsProperties["active_peers_last_day"] = activePeersLastDay
metricsProperties["user_peers"] = userPeers
metricsProperties["rules"] = rules
metricsProperties["groups"] = groups
metricsProperties["routes"] = routes
metricsProperties["nameservers"] = nameservers
metricsProperties["version"] = version
for os, count := range osPeers {
metricsProperties[os] = count
}
metricsProperties["metric_generation_time"] = time.Since(start).Milliseconds()
return metricsProperties
}
func getAPIKey(ctx context.Context) (string, error) {
httpClient := http.Client{}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, payloadEndpoint+"/GetToken", nil)
if err != nil {
return "", fmt.Errorf("unable to create request for metrics public api token %v", err)
}
response, err := httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("unable to request metrics public api token %v", err)
}
defer func() {
err = response.Body.Close()
if err != nil {
log.Errorf("error while closing metrics token response body: %v", err)
}
}()
if response.StatusCode != 200 {
return "", fmt.Errorf("unable to retrieve metrics token, statusCode %d", response.StatusCode)
}
body, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("coudln't get metrics token response; %v", err)
}
var tokenResponse getTokenResponse
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return "", fmt.Errorf("coudln't parse metrics public api token; %v", err)
}
return tokenResponse.PublicAPIToken, nil
}
func buildMetricsPayload(payload pushPayload) (string, error) {
str, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("unable to marshal metrics payload, got err: %v", err)
}
return string(str), nil
}
func createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
reqURL := endpoint
payload := strings.NewReader(payloadStr)
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/json")
return req, nil
}

View File

@@ -1,7 +1,6 @@
package mock_server
import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/route"
@@ -11,63 +10,53 @@ import (
)
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error)
GetPeerFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool) error
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, error)
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
SaveGroupFunc func(accountID string, group *server.Group) error
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
DeleteGroupFunc func(accountID, groupID string) error
ListGroupsFunc func(accountID string) ([]*server.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
SaveRuleFunc func(accountID string, rule *server.Rule) error
UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error)
DeleteRuleFunc func(accountID, ruleID string) error
ListRulesFunc func(accountID, userID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID string, route *route.Route) error
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error)
AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration) (*server.SetupKey, error)
RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error)
RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error)
GetPeerFunc func(peerKey string) (*server.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool) error
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, error)
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
SaveGroupFunc func(accountID string, group *server.Group) error
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
DeleteGroupFunc func(accountID, groupID string) error
ListGroupsFunc func(accountID string) ([]*server.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error)
SaveRuleFunc func(accountID string, rule *server.Rule) error
UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error)
DeleteRuleFunc func(accountID, ruleID string) error
ListRulesFunc func(accountID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
GetRouteFunc func(accountID, routeID string) (*route.Route, error)
SaveRouteFunc func(accountID string, route *route.Route) error
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID string) error
ListRoutesFunc func(accountID string) ([]*route.Route, error)
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) {
func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.UserInfo, error) {
if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(accountID, userID)
return am.GetUsersFromAccountFunc(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
}
@@ -93,18 +82,40 @@ func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account,
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented")
}
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
func (am *MockAccountManager) CreateSetupKey(
// AddSetupKey mock implementation of AddSetupKey from server.AccountManager interface
func (am *MockAccountManager) AddSetupKey(
accountId string,
keyName string,
keyType server.SetupKeyType,
expiresIn time.Duration,
autoGroups []string,
) (*server.SetupKey, error) {
if am.CreateSetupKeyFunc != nil {
return am.CreateSetupKeyFunc(accountId, keyName, keyType, expiresIn, autoGroups)
if am.AddSetupKeyFunc != nil {
return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey is not implemented")
}
// RevokeSetupKey mock implementation of RevokeSetupKey from server.AccountManager interface
func (am *MockAccountManager) RevokeSetupKey(
accountId string,
keyId string,
) (*server.SetupKey, error) {
if am.RevokeSetupKeyFunc != nil {
return am.RevokeSetupKeyFunc(accountId, keyId)
}
return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey is not implemented")
}
// RenameSetupKey mock implementation of RenameSetupKey from server.AccountManager interface
func (am *MockAccountManager) RenameSetupKey(
accountId string,
keyId string,
newName string,
) (*server.SetupKey, error) {
if am.RenameSetupKeyFunc != nil {
return am.RenameSetupKeyFunc(accountId, keyId, newName)
}
return nil, status.Errorf(codes.Unimplemented, "method RenameSetupKey is not implemented")
}
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
@@ -128,6 +139,19 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(
)
}
// GetAccountWithAuthorizationClaims mock implementation of GetAccountWithAuthorizationClaims from server.AccountManager interface
func (am *MockAccountManager) GetAccountWithAuthorizationClaims(
claims jwtclaims.AuthorizationClaims,
) (*server.Account, error) {
if am.GetAccountWithAuthorizationClaimsFunc != nil {
return am.GetAccountWithAuthorizationClaimsFunc(claims)
}
return nil, status.Errorf(
codes.Unimplemented,
"method GetAccountWithAuthorizationClaims is not implemented",
)
}
// AccountExists mock implementation of AccountExists from server.AccountManager interface
func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
if am.AccountExistsFunc != nil {
@@ -273,9 +297,9 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv
}
// GetRule mock implementation of GetRule from server.AccountManager interface
func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) {
if am.GetRuleFunc != nil {
return am.GetRuleFunc(accountID, ruleID, userID)
return am.GetRuleFunc(accountID, ruleID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented")
}
@@ -305,9 +329,9 @@ func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error {
}
// ListRules mock implementation of ListRules from server.AccountManager interface
func (am *MockAccountManager) ListRules(accountID, userID string) ([]*server.Rule, error) {
func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) {
if am.ListRulesFunc != nil {
return am.ListRulesFunc(accountID, userID)
return am.ListRulesFunc(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented")
}
@@ -346,16 +370,16 @@ func (am *MockAccountManager) UpdatePeer(accountID string, peer *server.Peer) (*
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {
if am.GetRouteFunc != nil {
return am.CreateRouteFunc(accountID, network, peer, description, netID, masquerade, metric, enabled)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
// GetRoute mock implementation of GetRoute from server.AccountManager interface
func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
func (am *MockAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) {
if am.GetRouteFunc != nil {
return am.GetRouteFunc(accountID, routeID, userID)
return am.GetRouteFunc(accountID, routeID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented")
}
@@ -385,116 +409,9 @@ func (am *MockAccountManager) DeleteRoute(accountID, routeID string) error {
}
// ListRoutes mock implementation of ListRoutes from server.AccountManager interface
func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, error) {
if am.ListRoutesFunc != nil {
return am.ListRoutesFunc(accountID, userID)
return am.ListRoutesFunc(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
}
// SaveSetupKey mocks SaveSetupKey of the AccountManager interface
func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
if am.SaveSetupKeyFunc != nil {
return am.SaveSetupKeyFunc(accountID, key)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveSetupKey is not implemented")
}
// GetSetupKey mocks GetSetupKey of the AccountManager interface
func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) {
if am.GetSetupKeyFunc != nil {
return am.GetSetupKeyFunc(accountID, userID, keyID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
}
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) {
if am.ListSetupKeysFunc != nil {
return am.ListSetupKeysFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
}
// SaveUser mocks SaveUser of the AccountManager interface
func (am *MockAccountManager) SaveUser(accountID string, user *server.User) (*server.UserInfo, error) {
if am.SaveUserFunc != nil {
return am.SaveUserFunc(accountID, user)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
}
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if am.GetNameServerGroupFunc != nil {
return am.GetNameServerGroupFunc(accountID, nsGroupID)
}
return nil, nil
}
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
if am.CreateNameServerGroupFunc != nil {
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled)
}
return nil, nil
}
// SaveNameServerGroup mocks SaveNameServerGroup of the AccountManager interface
func (am *MockAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
if am.SaveNameServerGroupFunc != nil {
return am.SaveNameServerGroupFunc(accountID, nsGroupToSave)
}
return nil
}
// UpdateNameServerGroup mocks UpdateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
if am.UpdateNameServerGroupFunc != nil {
return am.UpdateNameServerGroupFunc(accountID, nsGroupID, operations)
}
return nil, nil
}
// DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface
func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error {
if am.DeleteNameServerGroupFunc != nil {
return am.DeleteNameServerGroupFunc(accountID, nsGroupID)
}
return nil
}
// ListNameServerGroups mocks ListNameServerGroups of the AccountManager interface
func (am *MockAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
if am.ListNameServerGroupsFunc != nil {
return am.ListNameServerGroupsFunc(accountID)
}
return nil, nil
}
// CreateUser mocks CreateUser of the AccountManager interface
func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserInfo) (*server.UserInfo, error) {
if am.CreateUserFunc != nil {
return am.CreateUserFunc(accountID, invite)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
if am.GetAccountFromTokenFunc != nil {
return am.GetAccountFromTokenFunc(claims)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*server.Peer, error) {
if am.GetAccountFromTokenFunc != nil {
return am.GetPeersFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeersFunc is not implemented")
}

View File

@@ -1,379 +0,0 @@
package server
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/rs/xid"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strconv"
"unicode/utf8"
)
const (
// UpdateNameServerGroupName indicates a nameserver group name update operation
UpdateNameServerGroupName NameServerGroupUpdateOperationType = iota
// UpdateNameServerGroupDescription indicates a nameserver group description update operation
UpdateNameServerGroupDescription
// UpdateNameServerGroupNameServers indicates a nameserver group nameservers list update operation
UpdateNameServerGroupNameServers
// UpdateNameServerGroupGroups indicates a nameserver group' groups update operation
UpdateNameServerGroupGroups
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation
UpdateNameServerGroupEnabled
// UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation
UpdateNameServerGroupPrimary
// UpdateNameServerGroupDomains indicates a nameserver group' domains update operation
UpdateNameServerGroupDomains
)
// NameServerGroupUpdateOperationType operation type
type NameServerGroupUpdateOperationType int
func (t NameServerGroupUpdateOperationType) String() string {
switch t {
case UpdateNameServerGroupDescription:
return "UpdateNameServerGroupDescription"
case UpdateNameServerGroupName:
return "UpdateNameServerGroupName"
case UpdateNameServerGroupNameServers:
return "UpdateNameServerGroupNameServers"
case UpdateNameServerGroupGroups:
return "UpdateNameServerGroupGroups"
case UpdateNameServerGroupEnabled:
return "UpdateNameServerGroupEnabled"
case UpdateNameServerGroupPrimary:
return "UpdateNameServerGroupPrimary"
case UpdateNameServerGroupDomains:
return "UpdateNameServerGroupDomains"
default:
return "InvalidOperation"
}
}
// NameServerGroupUpdateOperation operation object with type and values to be applied
type NameServerGroupUpdateOperation struct {
Type NameServerGroupUpdateOperationType
Values []string
}
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
nsGroup, found := account.NameServerGroups[nsGroupID]
if found {
return nsGroup.Copy(), nil
}
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(),
Name: name,
Description: description,
NameServers: nameServerList,
Groups: groups,
Enabled: enabled,
Primary: primary,
Domains: domains,
}
err = validateNameServerGroup(false, newNSGroup, account)
if err != nil {
return nil, err
}
if account.NameServerGroups == nil {
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
}
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
return newNSGroup.Copy(), nil
}
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
am.mux.Lock()
defer am.mux.Unlock()
if nsGroupToSave == nil {
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
err = validateNameServerGroup(true, nsGroupToSave, account)
if err != nil {
return err
}
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
return nil
}
// UpdateNameServerGroup updates existing nameserver group with set of operations
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
if len(operations) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "operations shouldn't be empty")
}
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
if !ok {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
newNSGroup := nsGroupToUpdate.Copy()
for _, operation := range operations {
valuesCount := len(operation.Values)
if valuesCount < 1 {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
}
for _, value := range operation.Values {
if value == "" {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
}
}
switch operation.Type {
case UpdateNameServerGroupDescription:
newNSGroup.Description = operation.Values[0]
case UpdateNameServerGroupName:
if valuesCount > 1 {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
}
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
if err != nil {
return nil, err
}
newNSGroup.Name = operation.Values[0]
case UpdateNameServerGroupNameServers:
var nsList []nbdns.NameServer
for _, url := range operation.Values {
ns, err := nbdns.ParseNameServerURL(url)
if err != nil {
return nil, err
}
nsList = append(nsList, ns)
}
err = validateNSList(nsList)
if err != nil {
return nil, err
}
newNSGroup.NameServers = nsList
case UpdateNameServerGroupGroups:
err = validateGroups(operation.Values, account.Groups)
if err != nil {
return nil, err
}
newNSGroup.Groups = operation.Values
case UpdateNameServerGroupEnabled:
enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
}
newNSGroup.Enabled = enabled
case UpdateNameServerGroupPrimary:
primary, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
}
newNSGroup.Primary = primary
case UpdateNameServerGroupDomains:
err = validateDomainInput(false, operation.Values)
if err != nil {
return nil, err
}
newNSGroup.Domains = operation.Values
}
}
account.NameServerGroups[nsGroupID] = newNSGroup
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
return newNSGroup.Copy(), nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
return nil
}
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
for _, item := range account.NameServerGroups {
nsGroups = append(nsGroups, item.Copy())
}
return nsGroups, nil
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
err = validateNSList(nameserverGroup.NameServers)
if err != nil {
return err
}
err = validateGroups(nameserverGroup.Groups, account.Groups)
if err != nil {
return err
}
return nil
}
func validateDomainInput(primary bool, domains []string) error {
if !primary && len(domains) == 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
" it should be primary or have at least one domain")
}
if primary && len(domains) != 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
" you should set either primary or domain")
}
for _, domain := range domains {
_, valid := dns.IsDomainName(domain)
if !valid {
return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
}
}
return nil
}
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
}
for _, nsGroup := range nsGroupMap {
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(codes.InvalidArgument, "a nameserver group with name %s already exist", name)
}
}
return nil
}
func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 2 {
return status.Errorf(codes.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
}
return nil
}
func validateGroups(list []string, groups map[string]*Group) error {
if len(list) == 0 {
return status.Errorf(codes.InvalidArgument, "the list of group IDs should not be empty")
}
for _, id := range list {
if id == "" {
return status.Errorf(codes.InvalidArgument, "group ID should not be empty string")
}
found := false
for groupID := range groups {
if id == groupID {
found = true
break
}
}
if !found {
return status.Errorf(codes.InvalidArgument, "group id %s not found", id)
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -81,33 +81,6 @@ func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
return peer, nil
}
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
peers := make([]*Peer, 0, len(account.Peers))
for _, peer := range account.Peers {
if !user.IsAdmin() && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin
continue
}
peers = append(peers, peer.Copy())
}
return peers, nil
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
am.mux.Lock()
@@ -321,8 +294,6 @@ func (am *DefaultAccountManager) AddPeer(
var account *Account
var err error
var sk *SetupKey
// auto-assign groups that are coming with a SetupKey or a User
var groupsToAdd []string
if len(upperKey) != 0 {
account, err = am.Store.GetAccountBySetupKey(upperKey)
if err != nil {
@@ -350,20 +321,11 @@ func (am *DefaultAccountManager) AddPeer(
)
}
groupsToAdd = sk.AutoGroups
} else if len(userID) != 0 {
account, err = am.Store.GetUserAccount(userID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
}
user, ok := account.Users[userID]
if !ok {
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
}
groupsToAdd = user.AutoGroups
} else {
// Empty setup key and jwt fail
return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided")
@@ -399,14 +361,6 @@ func (am *DefaultAccountManager) AddPeer(
}
group.Peers = append(group.Peers, newPeer.Key)
if len(groupsToAdd) > 0 {
for _, s := range groupsToAdd {
if g, ok := account.Groups[s]; ok && g.Name != "All" {
g.Peers = append(g.Peers, newPeer.Key)
}
}
}
account.Peers[newPeer.Key] = newPeer
if len(upperKey) != 0 {
account.SetupKeys[sk.Key] = sk.IncrementUsage()

View File

@@ -134,7 +134,7 @@ func TestAccountManager_GetNetworkMapWithRule(t *testing.T) {
return
}
rules, err := manager.ListRules(account.Id, userId)
rules, err := manager.ListRules(account.Id)
if err != nil {
t.Errorf("expecting to get a list of rules, got failure %v", err)
return

View File

@@ -60,7 +60,7 @@ type RouteUpdateOperation struct {
}
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
func (am *DefaultAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
@@ -69,15 +69,6 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
}
wantedRoute, found := account.Routes[routeID]
if found {
return wantedRoute, nil
@@ -334,7 +325,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
func (am *DefaultAccountManager) ListRoutes(accountID string) ([]*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
@@ -343,15 +334,6 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
}
routes := make([]*route.Route, 0, len(account.Routes))
for _, item := range account.Routes {
routes = append(routes, item)

Some files were not shown because too many files have changed in this diff Show More