Compare commits

..

5 Commits

Author SHA1 Message Date
braginini
03d0f62ccd Add GroupMinimum to the SetupKey response 2022-09-12 14:30:46 +02:00
Misha Bragin
be7d829858 Add SetupKey auto-groups property (#460) 2022-09-11 23:16:40 +02:00
Maycon Santos
ed1872560f Use the client network for log errors (#455) 2022-09-07 18:26:59 +02:00
Maycon Santos
de898899a4 update slack invite tittle 2022-09-05 18:44:04 +02:00
Maycon Santos
b63ec71aed Check if login stream was canceled before printing warn (#451) 2022-09-05 17:44:26 +02:00
30 changed files with 892 additions and 1432 deletions

View File

@@ -16,7 +16,7 @@
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=netbirdio/netbird&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@wiretrustee-red.svg?logo=slack"/>
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
</p>
</div>
@@ -43,20 +43,20 @@ It requires zero configuration effort leaving behind the hassle of opening ports
NetBird creates an overlay peer-to-peer network connecting machines automatically regardless of their location (home, office, datacenter, container, cloud or edge environments) unifying virtual private network management experience.
**Key features:**
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
- \[x] Multiuser support - sharing network between multiple users.
- \[x] SSO and MFA support.
- \[x] Multicloud and hybrid-cloud support.
- \[x] Kernel WireGuard usage when possible.
- \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys.
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
- \[x] Multiuser support - sharing network between multiple users.
- \[x] SSO and MFA support.
- \[x] Multicloud and hybrid-cloud support.
- \[x] Kernel WireGuard usage when possible.
- \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys.
- \[x] Network Routes.
**Coming soon:**
- \[ ] Network Routes.
- \[ ] Private DNS.
- \[ ] Mobile clients.
- \[ ] Network Activity Monitoring.

View File

@@ -1,98 +0,0 @@
package main
/*
import (
"flag"
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"net/http"
_ "net/http/pprof"
"time"
)
var name = flag.String("name", "wg0", "WireGuard interface name")
var addr = flag.String("addr", "100.64.0.1/24", "interface WireGuard IP addr")
var key = flag.String("key", "100.64.0.1/24", "WireGuard private key")
var port = flag.Int("port", 51820, "WireGuard port")
var remoteKey = flag.String("remote-key", "", "remote WireGuard public key")
var remoteAddr = flag.String("remote-addr", "100.64.0.2/32", "remote WireGuard IP addr")
var remoteEndpoint = flag.String("remote-endpoint", "127.0.0.1:51820", "remote WireGuard endpoint")
func fff() {
flag.Parse()
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
myKey, err := wgtypes.ParseKey(*key)
if err != nil {
log.Error(err)
return
}
log.Infof("public key and addr [%s] [%s] ", myKey.PublicKey().String(), *addr)
wgIFace, err := iface.NewWGIFace(*name, *addr, 1280)
if err != nil {
log.Error(err)
return
}
defer wgIFace.Close()
// todo wrap into UDPMux
sharedSock, _, err := listenNet("udp4", *port)
if err != nil {
log.Error(err)
return
}
defer sharedSock.Close()
// err = wgIFace.Create()
err = wgIFace.CreateNew(sharedSock)
if err != nil {
log.Errorf("failed to create interface %s %v", *name, err)
return
}
err = wgIFace.Configure(*key, *port)
if err != nil {
log.Errorf("failed to configure interface %s %v", *name, err)
return
}
ip, err := net.ResolveUDPAddr("udp4", *remoteEndpoint)
if err != nil {
// handle error
}
err = wgIFace.UpdatePeer(*remoteKey, *remoteAddr, 20*time.Second, ip, nil)
if err != nil {
log.Errorf("failed to configure remote peer %s %v", *remoteKey, err)
return
}
select {}
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}*/

View File

@@ -8,6 +8,7 @@ import (
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/route"
"math/rand"
"net"
"reflect"
"runtime"
"strings"
@@ -88,7 +89,10 @@ type Engine struct {
wgInterface *iface.WGIface
iceMux ice.UniversalUDPMux
udpMux ice.UDPMux
udpMuxSrflx ice.UniversalUDPMux
udpMuxConn *net.UDPConn
udpMuxConnSrflx *net.UDPConn
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
@@ -151,6 +155,30 @@ func (e *Engine) Stop() error {
}
}
if e.udpMux != nil {
if err := e.udpMux.Close(); err != nil {
log.Debugf("close udp mux: %v", err)
}
}
if e.udpMuxSrflx != nil {
if err := e.udpMuxSrflx.Close(); err != nil {
log.Debugf("close server reflexive udp mux: %v", err)
}
}
if e.udpMuxConn != nil {
if err := e.udpMuxConn.Close(); err != nil {
log.Debugf("close udp mux connection: %v", err)
}
}
if e.udpMuxConnSrflx != nil {
if err := e.udpMuxConnSrflx.Close(); err != nil {
log.Debugf("close server reflexive udp mux connection: %v", err)
}
}
if !isNil(e.sshServer) {
err := e.sshServer.Stop()
if err != nil {
@@ -185,34 +213,34 @@ func (e *Engine) Start() error {
return err
}
bind := &iface.ICEBind{}
err = e.wgInterface.CreateNew(bind)
e.udpMuxConn, err = net.ListenUDP("udp4", &net.UDPAddr{Port: e.config.UDPMuxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
return err
}
e.udpMuxConnSrflx, err = net.ListenUDP("udp4", &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
return err
}
e.udpMux = ice.NewUDPMuxDefault(ice.UDPMuxParams{UDPConn: e.udpMuxConn})
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx})
err = e.wgInterface.Create()
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error())
return err
}
port, err := e.wgInterface.GetListenPort()
if err != nil {
return err
}
err = e.wgInterface.Configure(myPrivateKey.String(), *port)
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error())
return err
}
iceMux, err := bind.GetICEMux()
if err != nil {
return err
}
e.iceMux = iceMux
log.Infof("NetBird Engine started listening on WireGuard port %d", *port)
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
e.config.WgPort = *port
e.receiveSignalEvents()
e.receiveManagementEvents()
@@ -732,8 +760,8 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
StunTurn: stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
Timeout: timeout,
UDPMux: e.iceMux,
UDPMuxSrflx: e.iceMux,
UDPMux: e.udpMux,
UDPMuxSrflx: e.udpMuxSrflx,
ProxyConfig: proxyConfig,
LocalWgPort: e.config.WgPort,
}

View File

@@ -36,7 +36,10 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()

View File

@@ -147,7 +147,7 @@ func (conn *Conn) reCreateAgent() error {
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4},
Urls: conn.config.StunTurn,
CandidateTypes: []ice.CandidateType{ice.CandidateTypeServerReflexive, ice.CandidateTypeHost, ice.CandidateTypeRelay},
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
FailedTimeout: &failedTimeout,
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
@@ -280,7 +280,14 @@ func (conn *Conn) Open() error {
return err
}
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
if conn.proxy.Type() == proxy.TypeNoProxy {
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, iface.DefaultWgPort, rhost, iface.DefaultWgPort)
} else {
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
}
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
select {
@@ -344,16 +351,15 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
}
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
useProxy := shouldUseProxy(pair)
var p proxy.Proxy
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
if useProxy {
p = proxy.NewWireguardProxy(conn.config.ProxyConfig)
peerState.Direct = false
} else {
p = proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
peerState.Direct = true
}
conn.proxy = p
err = p.Start(remoteConn)
if err != nil {

View File

@@ -39,6 +39,7 @@ func (p *NoProxy) Start(remoteConn net.Conn) error {
if err != nil {
return err
}
addr.Port = p.RemoteWgListenPort
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)

View File

@@ -207,7 +207,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String())
if err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.chosenRoute.Network.String(), c.wgInterface.GetAddress().IP.String(), err)
c.network.String(), c.wgInterface.GetAddress().IP.String(), err)
}
}

9
go.mod
View File

@@ -39,13 +39,9 @@ require (
github.com/libp2p/go-netroute v0.2.0
github.com/magiconair/properties v1.8.5
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2
github.com/pion/stun v0.3.5
github.com/pion/transport v0.13.0
github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.7.1
go.uber.org/zap v1.17.0
golang.org/x/net v0.0.0-20220513224357-95641704303c
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
)
@@ -85,8 +81,11 @@ require (
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.1.2 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns v0.0.5 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/stun v0.3.5 // indirect
github.com/pion/transport v0.13.0 // indirect
github.com/pion/turn/v2 v2.0.7 // indirect
github.com/pion/udp v0.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -100,8 +99,6 @@ require (
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
github.com/yuin/goldmark v1.4.1 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect

3
go.sum
View File

@@ -646,11 +646,8 @@ go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/zap v1.17.0 h1:MTjgFu6ZLKvY6Pvaqk97GlxNBuMpV4Hy/3P6tRGlI2U=
go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=

View File

@@ -1,185 +0,0 @@
package iface
import (
"errors"
"fmt"
"github.com/pion/stun"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"net"
"net/netip"
"sync"
"syscall"
)
type ICEBind struct {
sharedConn net.PacketConn
udpMux *UniversalUDPMuxDefault
iceHostMux *UDPMuxDefault
mu sync.Mutex // protects following fields
}
func (b *ICEBind) GetICEMux() (UniversalUDPMux, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.udpMux, nil
}
func (b *ICEBind) GetICEHostMux() (UDPMux, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.iceHostMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.iceHostMux, nil
}
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.sharedConn != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
port := int(uport)
ipv4Conn, port, err := listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.sharedConn = ipv4Conn
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn})
portAddr1, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
if err != nil {
return nil, 0, err
}
log.Infof("opened ICEBind on %s", ipv4Conn.LocalAddr().String())
return []conn.ReceiveFunc{
b.makeReceiveIPv4(b.sharedConn),
},
portAddr1.Port(), nil
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: append([]byte{}, raw...),
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
return func(buff []byte) (int, conn.Endpoint, error) {
n, endpoint, err := c.ReadFrom(buff)
if err != nil {
return 0, nil, err
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
if !stun.IsMessage(buff[:20]) {
// WireGuard traffic
return n, (*conn.StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), nil
}
msg, err := parseSTUNMessage(buff[:n])
if err != nil {
return 0, nil, err
}
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
if err != nil {
return 0, nil, err
}
if err != nil {
log.Warnf("failed to handle packet")
}
// discard packets because they are STUN related
return 0, nil, nil //todo proper return
}
}
func (b *ICEBind) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
var err1, err2 error
if b.sharedConn != nil {
c := b.sharedConn
b.sharedConn = nil
err1 = c.Close()
}
if b.udpMux != nil {
m := b.udpMux
b.udpMux = nil
err2 = m.Close()
}
if err1 != nil {
return err1
}
return err2
}
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
func (b *ICEBind) SetMark(mark uint32) error {
return nil
}
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
nend, ok := endpoint.(*conn.StdNetEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
_, err := b.sharedConn.WriteTo(buff, (*net.UDPAddr)(nend))
return err
}
// ParseEndpoint creates a new endpoint from a string.
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
e, err := netip.ParseAddrPort(s)
return (*conn.StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), err
}

View File

@@ -55,6 +55,7 @@ func (w *WGIface) Configure(privateKey string, port int) error {
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
err = w.configureDevice(config)

View File

@@ -2,10 +2,6 @@ package iface
import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"net"
"os"
"runtime"
@@ -25,7 +21,6 @@ type WGIface struct {
Address WGAddress
Interface NetInterface
mu sync.Mutex
Bind *ICEBind
}
// WGAddress Wireguard parsed address
@@ -96,49 +91,3 @@ func (w *WGIface) Close() error {
return nil
}
func (w *WGIface) CreateNew(bind conn.Bind) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.createWithUserspaceNew(bind)
}
func (w *WGIface) createWithUserspaceNew(bind conn.Bind) error {
tunIface, err := tun.CreateTUN(w.Name, w.MTU)
if err != nil {
return err
}
w.Interface = tunIface
// We need to create a wireguard-go device and listen to configuration requests
tunDevice := device.NewDevice(tunIface, bind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
err = tunDevice.Up()
if err != nil {
return err
}
uapi, err := getUAPI(w.Name)
if err != nil {
return err
}
go func() {
for {
uapiConn, uapiErr := uapi.Accept()
if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr)
continue
}
go tunDevice.IpcHandle(uapiConn)
}
}()
log.Debugln("UAPI listener started")
err = w.assignAddr()
if err != nil {
return err
}
return nil
}

View File

@@ -39,7 +39,13 @@ func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.createWithUserspace()
if WireguardModExists() {
log.Info("using kernel WireGuard")
return w.createWithKernel()
} else {
log.Info("using userspace WireGuard")
return w.createWithUserspace()
}
}
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module.

View File

@@ -4,7 +4,6 @@ import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"net"
@@ -63,8 +62,3 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
func WireguardModExists() bool {
return false
}
// getUAPI returns a Listener
func getUAPI(iface string) (net.Listener, error) {
return ipc.UAPIListen(iface)
}

View File

@@ -1,288 +0,0 @@
package iface
import (
"fmt"
log "github.com/sirupsen/logrus"
"io"
"net"
"strings"
"sync"
"github.com/pion/logging"
"github.com/pion/stun"
)
const receiveMTU = 8192
// UDPMux allows multiple connections to go over a single UDP port
type UDPMux interface {
io.Closer
GetConn(ufrag string) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
}
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
closedChan chan struct{}
closeOnce sync.Once
// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
mu sync.Mutex
}
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
conns: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
}
}
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
for _, conn := range storedConns {
destinationConnList = append(destinationConnList, conn)
}
}
m.addressMapMu.Unlock()
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
if destinationConn, ok := m.conns[ufrag]; ok {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()
log.Debugf("ICE: getting muxed connection for %s", ufrag)
if m.IsClosed() {
return nil, io.ErrClosedPipe
}
if c, ok := m.conns[ufrag]; ok {
return c, nil
}
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.removeConn(ufrag)
}()
m.conns[ufrag] = c
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Lock()
removedConns := make([]*udpMuxedConn, 0)
for key := range m.conns {
if key != ufrag {
continue
}
c := m.conns[key]
delete(m.conns, key)
if c != nil {
removedConns = append(removedConns, c)
}
}
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
if connList, ok := m.addressMap[addr]; ok {
var newList []*udpMuxedConn
for _, conn := range connList {
if conn.params.Key != ufrag {
newList = append(newList, conn)
}
}
m.addressMap[addr] = newList
}
}
}
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
default:
return false
}
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
defer m.mu.Unlock()
for _, c := range m.conns {
_ = c.Close()
}
m.conns = make(map[string]*udpMuxedConn)
close(m.closedChan)
})
return err
}
func (m *UDPMuxDefault) removeConn(key string) {
m.mu.Lock()
c := m.conns[key]
delete(m.conns, key)
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
if c == nil {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
addresses := c.getAddresses()
for _, addr := range addresses {
if connList, ok := m.addressMap[addr]; ok {
var newList []*udpMuxedConn
for _, conn := range connList {
if conn.params.Key != key {
newList = append(newList, conn)
}
}
m.addressMap[addr] = newList
}
}
}
func (m *UDPMuxDefault) writeTo(buf []byte, raddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, raddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
log.Debugf("ICE: created muxed connection %s for %s", c.LocalAddr().String(), key)
return c
}
type bufferHolder struct {
buffer []byte
}
func newBufferHolder(size int) *bufferHolder {
return &bufferHolder{
buffer: make([]byte, size),
}
}

View File

@@ -1,235 +0,0 @@
package iface
import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"net"
"time"
"github.com/pion/logging"
"github.com/pion/stun"
)
// UniversalUDPMux allows multiple connections to go over a single UDP port for
// host, server reflexive and relayed candidates.
// Actual connection muxing is happening in the UDPMux.
type UniversalUDPMux interface {
UDPMux
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
GetConnForURL(ufrag string, url string) (net.PacketConn, error)
}
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
// It the passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
// stun.XORMappedAddress indexed by the STUN server addr
xorMappedMap map[string]*xorMapped
}
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
type UniversalUDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25
}
m := &UniversalUDPMuxDefault{
params: params,
xorMappedMap: make(map[string]*xorMapped),
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
}
// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr.
// Not implemented yet.
func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) {
return nil, errors.New("not implemented yet")
}
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url))
}
func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// message about this err will be logged in the UDPMux
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {
log.Debugf("%w: %v", errors.New("failed to get XOR-MAPPED-ADDRESS response"), err)
return nil
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
m.mu.Lock()
defer m.mu.Unlock()
// check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
_, ok := m.xorMappedMap[stunAddr]
_, err := msg.Get(stun.AttrXORMappedAddress)
return err == nil && ok
}
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
// and set the mapped address for the server
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
m.mu.Lock()
defer m.mu.Unlock()
mappedAddr, ok := m.xorMappedMap[stunAddr.String()]
if !ok {
return errors.New("no address mapping")
}
var addr stun.XORMappedAddress
if err := addr.GetFrom(msg); err != nil {
return err
}
m.xorMappedMap[stunAddr.String()] = mappedAddr
mappedAddr.SetAddr(&addr)
return nil
}
// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server.
// Makes a STUN binding request to discover mapped address otherwise.
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
m.mu.Lock()
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
// if we already have a mapping for this STUN server (address already received)
// and if it is not too old we return it without making a new request to STUN server
if ok {
if mappedAddr.expired() {
mappedAddr.closeWaiters()
delete(m.xorMappedMap, serverAddr.String())
ok = false
} else if mappedAddr.pending() {
ok = false
}
}
m.mu.Unlock()
if ok {
return mappedAddr.addr, nil
}
// otherwise, make a STUN request to discover the address
// or wait for already sent request to complete
waitAddrReceived, err := m.sendStun(serverAddr)
if err != nil {
return nil, errors.New("failed to send STUN packet")
}
// block until response was handled by the connWorker routine and XORMappedAddress was updated
select {
case <-waitAddrReceived:
// when channel closed, addr was obtained
m.mu.Lock()
mappedAddr := *m.xorMappedMap[serverAddr.String()]
m.mu.Unlock()
if mappedAddr.addr == nil {
return nil, errors.New("no address mapping")
}
return mappedAddr.addr, nil
case <-time.After(deadline):
return nil, errors.New("timeout while waiting for XORMappedAddr")
}
}
// sendStun sends a STUN request via UDP conn.
//
// The returned channel is closed when the STUN response has been received.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) {
m.mu.Lock()
defer m.mu.Unlock()
// if record present in the map, we already sent a STUN request,
// just wait when waitAddrReceived will be closed
addrMap, ok := m.xorMappedMap[serverAddr.String()]
if !ok {
addrMap = &xorMapped{
expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL),
waitAddrReceived: make(chan struct{}),
}
m.xorMappedMap[serverAddr.String()] = addrMap
}
req, err := stun.Build(stun.BindingRequest, stun.TransactionID)
if err != nil {
return nil, err
}
if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil {
return nil, err
}
return addrMap.waitAddrReceived, nil
}
type xorMapped struct {
addr *stun.XORMappedAddress
waitAddrReceived chan struct{}
expiresAt time.Time
}
func (a *xorMapped) closeWaiters() {
select {
case <-a.waitAddrReceived:
// notify was close, ok, that means we received duplicate response
// just exit
break
default:
// notify tha twe have a new addr
close(a.waitAddrReceived)
}
}
func (a *xorMapped) pending() bool {
return a.addr == nil
}
func (a *xorMapped) expired() bool {
return a.expiresAt.Before(time.Now())
}
func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) {
a.addr = addr
a.closeWaiters()
}

View File

@@ -1,246 +0,0 @@
package iface
import (
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/transport/packetio"
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
type udpMuxedConn struct {
params *udpMuxedConnParams
// remote addresses that we have sent to on this conn
addresses []string
// channel holding incoming packets
buffer *packetio.Buffer
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
}
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
params: params,
buffer: packetio.NewBuffer(),
closedChan: make(chan struct{}),
}
return p
}
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) {
buf := c.params.AddrPool.Get().(*bufferHolder)
defer c.params.AddrPool.Put(buf)
// read address
total, err := c.buffer.Read(buf.buffer)
if err != nil {
return 0, nil, err
}
dataLen := int(binary.LittleEndian.Uint16(buf.buffer[:2]))
if dataLen > total || dataLen > len(b) {
return 0, nil, io.ErrShortBuffer
}
// read data and then address
offset := 2
copy(b, buf.buffer[offset:offset+dataLen])
offset += dataLen
// read address len & decode address
addrLen := int(binary.LittleEndian.Uint16(buf.buffer[offset : offset+2]))
offset += 2
if raddr, err = decodeUDPAddr(buf.buffer[offset : offset+addrLen]); err != nil {
return 0, nil, err
}
return dataLen, raddr, nil
}
func (c *udpMuxedConn) WriteTo(buf []byte, raddr net.Addr) (n int, err error) {
if c.isClosed() {
return 0, io.ErrClosedPipe
}
// each time we write to a new address, we'll register it with the mux
addr := raddr.String()
if !c.containsAddress(addr) {
c.addAddress(addr)
}
return c.params.Mux.writeTo(buf, raddr)
}
func (c *udpMuxedConn) LocalAddr() net.Addr {
return c.params.LocalAddr
}
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
return c.closedChan
}
func (c *udpMuxedConn) Close() error {
var err error
c.closeOnce.Do(func() {
err = c.buffer.Close()
close(c.closedChan)
})
c.mu.Lock()
defer c.mu.Unlock()
c.addresses = nil
return err
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
default:
return false
}
}
func (c *udpMuxedConn) getAddresses() []string {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}
func (c *udpMuxedConn) addAddress(addr string) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
// map it on mux
c.params.Mux.registerConnForAddress(c, addr)
}
func (c *udpMuxedConn) removeAddress(addr string) {
c.mu.Lock()
defer c.mu.Unlock()
newAddresses := make([]string, 0, len(c.addresses))
for _, a := range c.addresses {
if a != addr {
newAddresses = append(newAddresses, a)
}
}
c.addresses = newAddresses
}
func (c *udpMuxedConn) containsAddress(addr string) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
if addr == a {
return true
}
}
return false
}
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
// write two packets, address and data
buf := c.params.AddrPool.Get().(*bufferHolder)
defer c.params.AddrPool.Put(buf)
// format of buffer | data len | data bytes | addr len | addr bytes |
if len(buf.buffer) < len(data)+maxAddrSize {
return io.ErrShortBuffer
}
// data len
binary.LittleEndian.PutUint16(buf.buffer, uint16(len(data)))
offset := 2
// data
copy(buf.buffer[offset:], data)
offset += len(data)
// write address first, leaving room for its length
n, err := encodeUDPAddr(addr, buf.buffer[offset+2:])
if err != nil {
return nil
}
total := offset + n + 2
// address len
binary.LittleEndian.PutUint16(buf.buffer[offset:], uint16(n))
if _, err := c.buffer.Write(buf.buffer[:total]); err != nil {
return err
}
return nil
}
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
ipdata, err := addr.IP.MarshalText()
if err != nil {
return 0, err
}
total := 2 + len(ipdata) + 2 + len(addr.Zone)
if total > len(buf) {
return 0, io.ErrShortBuffer
}
binary.LittleEndian.PutUint16(buf, uint16(len(ipdata)))
offset := 2
n := copy(buf[offset:], ipdata)
offset += n
binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
offset += 2
copy(buf[offset:], addr.Zone)
return total, nil
}
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
addr := net.UDPAddr{}
offset := 0
ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
offset += 2
// basic bounds checking
if ipLen+offset > len(buf) {
return nil, io.ErrShortBuffer
}
if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
return nil, err
}
offset += ipLen
addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
offset += 2
zone := make([]byte, len(buf[offset:]))
copy(zone, buf[offset:])
addr.Zone = string(zone)
return &addr, nil
}

View File

@@ -109,7 +109,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return err
}
cancel, stream, err := c.connectToStream(*serverPubKey)
stream, err := c.connectToStream(*serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -117,7 +117,6 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
}
return err
}
defer cancel()
log.Infof("connected to the Management Service stream")
@@ -146,7 +145,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return nil
}
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (context.CancelFunc, proto.ManagementService_SyncClient, error) {
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}
myPrivateKey := c.key
@@ -155,16 +154,11 @@ func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (context.CancelFu
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
if err != nil {
log.Errorf("failed encrypting message: %s", err)
return nil, nil, err
return nil, err
}
ctx, cancel := context.WithCancel(c.ctx)
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
sync, err := c.realClient.Sync(ctx, syncReq)
if err != nil {
cancel()
return nil, nil, err
}
return cancel, sync, nil
return c.realClient.Sync(c.ctx, syncReq)
}
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {

View File

@@ -31,14 +31,15 @@ const (
type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
AddSetupKey(
CreateSetupKey(
accountId string,
keyName string,
keyType SetupKeyType,
expiresIn time.Duration,
autoGroups []string,
) (*SetupKey, error)
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
GetSetupKey(accountID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error)
@@ -75,6 +76,7 @@ type AccountManager interface {
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID string) error
ListRoutes(accountID string) ([]*route.Route, error)
ListSetupKeys(accountID string) ([]*SetupKey, error)
}
type DefaultAccountManager struct {
@@ -244,93 +246,6 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return nil
}
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
func (am *DefaultAccountManager) AddSetupKey(
accountId string,
keyName string,
keyType SetupKeyType,
expiresIn time.Duration,
) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 {
keyDuration = expiresIn
}
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := GenerateSetupKey(keyName, keyType, keyDuration)
account.SetupKeys[setupKey.Key] = setupKey
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return setupKey, nil
}
// RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore
func (am *DefaultAccountManager) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := getAccountSetupKeyById(account, keyId)
if setupKey == nil {
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
}
keyCopy := setupKey.Copy()
keyCopy.Revoked = true
account.SetupKeys[keyCopy.Key] = keyCopy
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return keyCopy, nil
}
// RenameSetupKey renames existing setup key of the specified account.
func (am *DefaultAccountManager) RenameSetupKey(
accountId string,
keyId string,
newName string,
) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := getAccountSetupKeyById(account, keyId)
if setupKey == nil {
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
}
keyCopy := setupKey.Copy()
keyCopy.Name = newName
account.SetupKeys[keyCopy.Key] = keyCopy
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return keyCopy, nil
}
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) {
am.mux.Lock()
@@ -504,7 +419,6 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
//
//
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
@@ -688,7 +602,7 @@ func newAccountWithId(accountId, userId, domain string) *Account {
setupKeys := make(map[string]*SetupKey)
defaultKey := GenerateDefaultSetupKey()
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration)
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{})
setupKeys[defaultKey.Key] = defaultKey
setupKeys[oneOffKey.Key] = oneOffKey
network := NewNetwork()
@@ -713,15 +627,6 @@ func newAccountWithId(accountId, userId, domain string) *Account {
return acc
}
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {
for _, k := range acc.SetupKeys {
if keyId == k.Id {
return k
}
}
return nil
}
func getAccountSetupKeyByKey(acc *Account, key string) *SetupKey {
for _, k := range acc.SetupKeys {
if key == k.Key {

View File

@@ -134,6 +134,15 @@ components:
state:
description: Setup key status, "valid", "overused","expired" or "revoked"
type: string
auto_groups:
description: Setup key groups to auto-assign to peers registered with this key
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
updated_at:
description: Setup key last update date
type: string
format: date-time
required:
- id
- key
@@ -145,6 +154,8 @@ components:
- used_times
- last_used
- state
- auto_groups
- updated_at
SetupKeyRequest:
type: object
properties:
@@ -160,11 +171,17 @@ components:
revoked:
description: Setup key revocation status
type: boolean
auto_groups:
description: Setup key groups to auto-assign to peers registered with this key
type: array
items:
type: string
required:
- name
- type
- expires_in
- revoked
- auto_groups
GroupMinimum:
type: object
properties:

View File

@@ -299,6 +299,9 @@ type RulePatchOperationPath string
// SetupKey defines model for SetupKey.
type SetupKey struct {
// Setup key groups to auto-assign to peers registered with this key
AutoGroups []GroupMinimum `json:"auto_groups"`
// Setup Key expiration date
Expires time.Time `json:"expires"`
@@ -323,6 +326,9 @@ type SetupKey struct {
// Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// Usage count of setup key
UsedTimes int `json:"used_times"`
@@ -332,6 +338,9 @@ type SetupKey struct {
// SetupKeyRequest defines model for SetupKeyRequest.
type SetupKeyRequest struct {
// Setup key groups to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Expiration time in seconds
ExpiresIn int `json:"expires_in"`

View File

@@ -344,6 +344,14 @@ func peerIPsToKeys(account *server.Account, peerIPs *[]string) []string {
return mappedPeerKeys
}
func toGroupMinimumResponse(group *server.Group) *api.GroupMinimum {
return &api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
}
func toGroupResponse(account *server.Account, group *server.Group) *api.Group {
cache := make(map[string]api.PeerMinimum)
gr := api.Group{

View File

@@ -39,12 +39,11 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "PUT", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")

View File

@@ -348,6 +348,11 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
err = h.accountManager.DeleteRoute(account.Id, routeID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return

View File

@@ -78,7 +78,10 @@ func initRoutesTestData() *Routes {
SaveRouteFunc: func(_ string, _ *route.Route) error {
return nil
},
DeleteRouteFunc: func(_ string, _ string) error {
DeleteRouteFunc: func(_ string, peerIP string) error {
if peerIP != existingRouteID {
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
}
return nil
},
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
@@ -155,7 +158,7 @@ func TestRoutesHandlers(t *testing.T) {
{
name: "Get Not Existing Route",
requestType: http.MethodGet,
requestPath: "/api/rules/" + notFoundRouteID,
requestPath: "/api/routes/" + notFoundRouteID,
expectedStatus: http.StatusNotFound,
},
{
@@ -168,7 +171,7 @@ func TestRoutesHandlers(t *testing.T) {
{
name: "Delete Not Existing Route",
requestType: http.MethodDelete,
requestPath: "/api/rules/" + notFoundRouteID,
requestPath: "/api/routes/" + notFoundRouteID,
expectedStatus: http.StatusNotFound,
},
{

View File

@@ -2,6 +2,7 @@ package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -28,54 +29,17 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
}
}
func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
req := &api.PutApiSetupKeysIdJSONRequestBody{}
err := json.NewDecoder(r.Body).Decode(&req)
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
var key *server.SetupKey
if req.Revoked {
//handle only if being revoked, don't allow to enable key again for now
key, err = h.accountManager.RevokeSetupKey(accountId, keyId)
if err != nil {
http.Error(w, "failed revoking key", http.StatusInternalServerError)
return
}
}
if len(req.Name) != 0 {
key, err = h.accountManager.RenameSetupKey(accountId, keyId, req.Name)
if err != nil {
http.Error(w, "failed renaming key", http.StatusInternalServerError)
return
}
}
if key != nil {
writeSuccess(w, key)
}
}
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
account, err := h.accountManager.GetAccountById(accountId)
if err != nil {
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
return
}
for _, key := range account.SetupKeys {
if key.Id == keyId {
writeSuccess(w, key)
return
}
}
http.Error(w, "setup key not found", http.StatusNotFound)
}
func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.Request) {
req := &api.PostApiSetupKeysJSONRequestBody{}
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@@ -95,7 +59,13 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
expiresIn := time.Duration(req.ExpiresIn) * time.Second
setupKey, err := h.accountManager.AddSetupKey(accountId, req.Name, server.SetupKeyType(req.Type), expiresIn)
if req.AutoGroups == nil {
req.AutoGroups = []string{}
}
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
@@ -109,7 +79,8 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
writeSuccess(w, setupKey)
}
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
@@ -118,25 +89,84 @@ func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
keyId := vars["id"]
if len(keyId) == 0 {
keyID := vars["id"]
if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest)
return
}
switch r.Method {
case http.MethodPut:
h.updateKey(account.Id, keyId, w, r)
key, err := h.accountManager.GetSetupKey(account.Id, keyID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
case http.MethodGet:
h.getKey(account.Id, keyId, w, r)
return
default:
http.Error(w, "", http.StatusNotFound)
}
writeSuccess(w, key)
}
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
vars := mux.Vars(r)
keyID := vars["id"]
if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest)
return
}
req := &api.PutApiSetupKeysIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name == "" {
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest)
return
}
if req.AutoGroups == nil {
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest)
return
}
newKey := &server.SetupKey{}
newKey.AutoGroups = req.AutoGroups
newKey.Revoked = req.Revoked
newKey.Name = req.Name
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
if err != nil {
if e, ok := status.FromError(err); ok {
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
default:
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
}
}
return
}
writeSuccess(w, newKey)
}
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
@@ -145,41 +175,40 @@ func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
return
}
switch r.Method {
case http.MethodPost:
h.createKey(account.Id, w, r)
setupKeys, err := h.accountManager.ListSetupKeys(account.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
case http.MethodGet:
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
respBody := []*api.SetupKey{}
for _, key := range account.SetupKeys {
respBody = append(respBody, toResponseBody(key))
}
err = json.NewEncoder(w).Encode(respBody)
if err != nil {
log.Errorf("failed encoding account peers %s: %v", account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
default:
http.Error(w, "", http.StatusNotFound)
}
groups, err := h.accountManager.ListGroups(account.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys {
apiSetupKeys = append(apiSetupKeys, toResponseBody(groups, key))
}
writeJSONObject(w, apiSetupKeys)
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(toResponseBody(key))
err := json.NewEncoder(w).Encode(toResponseBody(nil, key))
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
return
}
}
func toResponseBody(key *server.SetupKey) *api.SetupKey {
// toResponseBody takes a list of all groups, a key, finds groups that belong to the key and add them to the response
func toResponseBody(groups []*server.Group, key *server.SetupKey) *api.SetupKey {
var state string
if key.IsExpired() {
state = "expired"
@@ -190,16 +219,30 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey {
} else {
state = "valid"
}
// should be instantiated like that to ensure if empty, we return an empty array to the client
autoGroups := []api.GroupMinimum{}
for _, group := range groups {
for _, keyGroup := range key.AutoGroups {
if group.ID == keyGroup {
autoGroups = append(autoGroups, *toGroupMinimumResponse(group))
break
}
}
}
return &api.SetupKey{
Id: key.Id,
Key: key.Key,
Name: key.Name,
Expires: key.ExpiresAt,
Type: string(key.Type),
Valid: key.IsValid(),
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
State: state,
Id: key.Id,
Key: key.Key,
Name: key.Name,
Expires: key.ExpiresAt,
Type: string(key.Type),
Valid: key.IsValid(),
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
State: state,
AutoGroups: autoGroups,
UpdatedAt: key.UpdatedAt,
}
}

View File

@@ -0,0 +1,222 @@
package http
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const (
existingSetupKeyID = "existingSetupKeyID"
newSetupKeyName = "New Setup Key"
updatedSetupKeyName = "KKKey"
notFoundSetupKeyID = "notFoundSetupKeyID"
)
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys {
return &SetupKeys{
accountManager: &mock_server.MockAccountManager{
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
Groups: map[string]*server.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
}, nil
},
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type {
return newKey, nil
}
return nil, fmt.Errorf("failed creating setup key")
},
GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) {
switch keyID {
case defaultKey.Id:
return defaultKey, nil
case newKey.Id:
return newKey, nil
default:
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID)
}
},
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil
}
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testAccountID,
}
},
},
}
}
func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
updatedDefaultSetupKey.Name = updatedSetupKeyName
updatedDefaultSetupKey.Revoked = true
tt := []struct {
name string
requestType string
requestPath string
requestBody io.Reader
expectedStatus int
expectedBody bool
expectedSetupKey *api.SetupKey
expectedSetupKeys []*api.SetupKey
}{
{
name: "Get Setup Keys",
requestType: http.MethodGet,
requestPath: "/api/setup-keys",
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKeys: []*api.SetupKey{toResponseBody(nil, defaultSetupKey)},
},
{
name: "Get Existing Setup Key",
requestType: http.MethodGet,
requestPath: "/api/setup-keys/" + existingSetupKeyID,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(nil, defaultSetupKey),
},
{
name: "Get Not Existing Setup Key",
requestType: http.MethodGet,
requestPath: "/api/setup-keys/" + notFoundSetupKeyID,
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
name: "Create Setup Key",
requestType: http.MethodPost,
requestPath: "/api/setup-keys",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\"}", newSetupKey.Name, newSetupKey.Type))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(nil, newSetupKey),
},
{
name: "Update Setup Key",
requestType: http.MethodPut,
requestPath: "/api/setup-keys/" + defaultSetupKey.Id,
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"auto_groups\":[\"%s\"], \"revoked\":%v}",
updatedDefaultSetupKey.Type,
updatedDefaultSetupKey.AutoGroups[0],
updatedDefaultSetupKey.Revoked,
))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(nil, updatedDefaultSetupKey),
},
}
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{id}", handler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{id}", handler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
if tc.expectedSetupKey != nil {
got := &api.SetupKey{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assertKeys(t, got, tc.expectedSetupKey)
return
}
if len(tc.expectedSetupKeys) > 0 {
var got []*api.SetupKey
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assertKeys(t, got[0], tc.expectedSetupKeys[0])
return
}
})
}
}
func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) {
// this comparison is done manually because when converting to JSON dates formatted differently
// assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work
assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "")
assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "")
assert.Equal(t, got.Name, expected.Name)
assert.Equal(t, got.Id, expected.Id)
assert.Equal(t, got.Key, expected.Key)
assert.Equal(t, got.Type, expected.Type)
assert.Equal(t, got.UsedTimes, expected.UsedTimes)
assert.Equal(t, got.Revoked, expected.Revoked)
assert.ElementsMatch(t, got.AutoGroups, expected.AutoGroups)
}

View File

@@ -12,9 +12,8 @@ import (
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error)
AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration) (*server.SetupKey, error)
RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error)
RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
@@ -51,6 +50,8 @@ type MockAccountManager struct {
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID string) error
ListRoutesFunc func(accountID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@@ -82,40 +83,18 @@ func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account,
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented")
}
// AddSetupKey mock implementation of AddSetupKey from server.AccountManager interface
func (am *MockAccountManager) AddSetupKey(
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
func (am *MockAccountManager) CreateSetupKey(
accountId string,
keyName string,
keyType server.SetupKeyType,
expiresIn time.Duration,
autoGroups []string,
) (*server.SetupKey, error) {
if am.AddSetupKeyFunc != nil {
return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn)
if am.CreateSetupKeyFunc != nil {
return am.CreateSetupKeyFunc(accountId, keyName, keyType, expiresIn, autoGroups)
}
return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey is not implemented")
}
// RevokeSetupKey mock implementation of RevokeSetupKey from server.AccountManager interface
func (am *MockAccountManager) RevokeSetupKey(
accountId string,
keyId string,
) (*server.SetupKey, error) {
if am.RevokeSetupKeyFunc != nil {
return am.RevokeSetupKeyFunc(accountId, keyId)
}
return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey is not implemented")
}
// RenameSetupKey mock implementation of RenameSetupKey from server.AccountManager interface
func (am *MockAccountManager) RenameSetupKey(
accountId string,
keyId string,
newName string,
) (*server.SetupKey, error) {
if am.RenameSetupKeyFunc != nil {
return am.RenameSetupKeyFunc(accountId, keyId, newName)
}
return nil, status.Errorf(codes.Unimplemented, "method RenameSetupKey is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
@@ -415,3 +394,30 @@ func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, erro
}
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
}
// SaveSetupKey mocks SaveSetupKey of the AccountManager interface
func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
if am.SaveSetupKeyFunc != nil {
return am.SaveSetupKeyFunc(accountID, key)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveSetupKey is not implemented")
}
// GetSetupKey mocks GetSetupKey of the AccountManager interface
func (am *MockAccountManager) GetSetupKey(accountID, keyID string) (*server.SetupKey, error) {
if am.GetSetupKeyFunc != nil {
return am.GetSetupKeyFunc(accountID, keyID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
}
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKey, error) {
if am.ListSetupKeysFunc != nil {
return am.ListSetupKeysFunc(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
}

View File

@@ -1,7 +1,10 @@
package server
import (
"fmt"
"github.com/google/uuid"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"hash/fnv"
"strconv"
"strings"
@@ -18,8 +21,41 @@ const (
DefaultSetupKeyDuration = 24 * 30 * time.Hour
// DefaultSetupKeyName is a default name of the default setup key
DefaultSetupKeyName = "Default key"
// UpdateSetupKeyName indicates a setup key name update operation
UpdateSetupKeyName SetupKeyUpdateOperationType = iota
// UpdateSetupKeyRevoked indicates a setup key revoked filed update operation
UpdateSetupKeyRevoked
// UpdateSetupKeyAutoGroups indicates a setup key auto-assign groups update operation
UpdateSetupKeyAutoGroups
// UpdateSetupKeyExpiresAt indicates a setup key expiration time update operation
UpdateSetupKeyExpiresAt
)
// SetupKeyUpdateOperationType operation type
type SetupKeyUpdateOperationType int
func (t SetupKeyUpdateOperationType) String() string {
switch t {
case UpdateSetupKeyName:
return "UpdateSetupKeyName"
case UpdateSetupKeyRevoked:
return "UpdateSetupKeyRevoked"
case UpdateSetupKeyAutoGroups:
return "UpdateSetupKeyAutoGroups"
case UpdateSetupKeyExpiresAt:
return "UpdateSetupKeyExpiresAt"
default:
return "InvalidOperation"
}
}
// SetupKeyUpdateOperation operation object with type and values to be applied
type SetupKeyUpdateOperation struct {
Type SetupKeyUpdateOperationType
Values []string
}
// SetupKeyType is the type of setup key
type SetupKeyType string
@@ -31,30 +67,38 @@ type SetupKey struct {
Type SetupKeyType
CreatedAt time.Time
ExpiresAt time.Time
UpdatedAt time.Time
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
Revoked bool
// UsedTimes indicates how many times the key was used
UsedTimes int
// LastUsed last time the key was used for peer registration
LastUsed time.Time
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
AutoGroups []string
}
//Copy copies SetupKey to a new object
// Copy copies SetupKey to a new object
func (key *SetupKey) Copy() *SetupKey {
if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt
}
return &SetupKey{
Id: key.Id,
Key: key.Key,
Name: key.Name,
Type: key.Type,
CreatedAt: key.CreatedAt,
ExpiresAt: key.ExpiresAt,
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
Id: key.Id,
Key: key.Key,
Name: key.Name,
Type: key.Type,
CreatedAt: key.CreatedAt,
ExpiresAt: key.ExpiresAt,
UpdatedAt: key.UpdatedAt,
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
AutoGroups: key.AutoGroups,
}
}
//IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
func (key *SetupKey) IncrementUsage() *SetupKey {
c := key.Copy()
c.UsedTimes = c.UsedTimes + 1
@@ -83,24 +127,25 @@ func (key *SetupKey) IsOverUsed() bool {
}
// GenerateSetupKey generates a new setup key
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration) *SetupKey {
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string) *SetupKey {
key := strings.ToUpper(uuid.New().String())
createdAt := time.Now()
return &SetupKey{
Id: strconv.Itoa(int(Hash(key))),
Key: key,
Name: name,
Type: t,
CreatedAt: createdAt,
ExpiresAt: createdAt.Add(validFor),
Revoked: false,
UsedTimes: 0,
Id: strconv.Itoa(int(Hash(key))),
Key: key,
Name: name,
Type: t,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(validFor),
UpdatedAt: time.Now(),
Revoked: false,
UsedTimes: 0,
AutoGroups: autoGroups,
}
}
// GenerateDefaultSetupKey generates a default setup key
func GenerateDefaultSetupKey() *SetupKey {
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration)
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{})
}
func Hash(s string) uint32 {
@@ -111,3 +156,127 @@ func Hash(s string) uint32 {
}
return h.Sum32()
}
// CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key,
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 {
keyDuration = expiresIn
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
for _, group := range autoGroups {
if _, ok := account.Groups[group]; !ok {
return nil, fmt.Errorf("group %s doesn't exist", group)
}
}
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups)
account.SetupKeys[setupKey.Key] = setupKey
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return setupKey, nil
}
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
// Due to the unique nature of a SetupKey certain properties must not be overwritten
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
if keyToSave == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
var oldKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyToSave.Id {
oldKey = key.Copy()
break
}
}
if oldKey == nil {
return nil, status.Errorf(codes.NotFound, "setup key not found")
}
// only auto groups, revoked status, and name can be updated for now
newKey := oldKey.Copy()
newKey.Name = keyToSave.Name
newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now()
account.SetupKeys[newKey.Key] = newKey
if err = am.Store.SaveAccount(account); err != nil {
return nil, err
}
return newKey, am.updateAccountPeers(account)
}
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
keys := make([]*SetupKey, 0, len(account.SetupKeys))
for _, key := range account.SetupKeys {
keys = append(keys, key.Copy())
}
return keys, nil
}
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
var foundKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyID {
foundKey = key.Copy()
break
}
}
if foundKey == nil {
return nil, status.Errorf(codes.NotFound, "setup key not found")
}
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
if foundKey.UpdatedAt.IsZero() {
foundKey.UpdatedAt = foundKey.CreatedAt
}
return foundKey, nil
}

View File

@@ -2,23 +2,159 @@ package server
import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"strconv"
"testing"
"time"
)
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
userID := "test_user"
account, err := manager.GetOrCreateAccountByUser(userID, "")
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(account.Id, &Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
})
if err != nil {
t.Fatal(err)
}
expiresIn := time.Hour
keyName := "my-test-key"
key, err := manager.CreateSetupKey(account.Id, keyName, SetupKeyReusable, expiresIn, []string{})
if err != nil {
t.Fatal(err)
}
autoGroups := []string{"group_1", "group_2"}
newKeyName := "my-new-test-key"
revoked := true
newKey, err := manager.SaveSetupKey(account.Id, &SetupKey{
Id: key.Id,
Name: newKeyName,
Revoked: revoked,
AutoGroups: autoGroups,
})
if err != nil {
t.Fatal(err)
}
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
key.Id, time.Now(), autoGroups)
}
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
userID := "test_user"
account, err := manager.GetOrCreateAccountByUser(userID, "")
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(account.Id, &Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
})
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(account.Id, &Group{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
})
if err != nil {
t.Fatal(err)
}
type testCase struct {
name string
expectedKeyName string
expectedUsedTimes int
expectedType string
expectedGroups []string
expectedCreatedAt time.Time
expectedUpdatedAt time.Time
expectedExpiresAt time.Time
expectedFailure bool //indicates whether key creation should fail
}
now := time.Now()
expiresIn := time.Hour
testCase1 := testCase{
name: "Should Create Setup Key successfully",
expectedKeyName: "my-test-key",
expectedUsedTimes: 0,
expectedType: "reusable",
expectedGroups: []string{"group_1", "group_2"},
expectedCreatedAt: now,
expectedUpdatedAt: now,
expectedExpiresAt: now.Add(expiresIn),
expectedFailure: false,
}
testCase2 := testCase{
name: "Create Setup Key should fail because of unexistent group",
expectedKeyName: "my-test-key",
expectedGroups: []string{"FAKE"},
expectedFailure: true,
}
for _, tCase := range []testCase{testCase1, testCase2} {
t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups)
if tCase.expectedFailure {
if err == nil {
t.Fatal("expected to fail")
}
return
}
if err != nil {
t.Fatal(err)
}
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))),
tCase.expectedUpdatedAt, tCase.expectedGroups)
})
}
}
func TestGenerateDefaultSetupKey(t *testing.T) {
expectedName := "Default key"
expectedRevoke := false
expectedType := "reusable"
expectedUsedTimes := 0
expectedCreatedAt := time.Now()
expectedUpdatedAt := time.Now()
expectedExpiresAt := time.Now().Add(24 * 30 * time.Hour)
var expectedAutoGroups []string
key := GenerateDefaultSetupKey()
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
}
@@ -29,41 +165,44 @@ func TestGenerateSetupKey(t *testing.T) {
expectedUsedTimes := 0
expectedCreatedAt := time.Now()
expectedExpiresAt := time.Now().Add(time.Hour)
expectedUpdatedAt := time.Now()
var expectedAutoGroups []string
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour)
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{})
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
}
func TestSetupKey_IsValid(t *testing.T) {
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour)
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{})
if !validKey.IsValid() {
t.Errorf("expected key to be valid, got invalid %v", validKey)
}
// expired
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour)
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{})
if expiredKey.IsValid() {
t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey)
}
// revoked
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour)
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{})
revokedKey.Revoked = true
if revokedKey.IsValid() {
t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey)
}
// overused
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour)
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{})
overUsedKey.UsedTimes = 1
if overUsedKey.IsValid() {
t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey)
}
// overused
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour)
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{})
reusableKey.UsedTimes = 99
if !reusableKey.IsValid() {
t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey)
@@ -71,7 +210,8 @@ func TestSetupKey_IsValid(t *testing.T) {
}
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string) {
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string,
expectedUpdatedAt time.Time, expectedAutoGroups []string) {
if key.Name != expectedName {
t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name)
}
@@ -92,6 +232,10 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
t.Errorf("expected setup key to have ExpiresAt ~ %v, got %v", expectedExpiresAt, key.ExpiresAt)
}
if key.UpdatedAt.Sub(expectedUpdatedAt).Round(time.Hour) != 0 {
t.Errorf("expected setup key to have UpdatedAt ~ %v, got %v", expectedUpdatedAt, key.UpdatedAt)
}
if key.CreatedAt.Sub(expectedCreatedAt).Round(time.Hour) != 0 {
t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt)
}
@@ -104,13 +248,19 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
if key.Id != strconv.Itoa(int(Hash(key.Key))) {
t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id)
}
if len(key.AutoGroups) != len(expectedAutoGroups) {
t.Errorf("expected key AutoGroups size=%d, got %d", len(expectedAutoGroups), len(key.AutoGroups))
}
assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal")
}
func TestSetupKey_Copy(t *testing.T) {
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour)
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{})
keyCopy := key.Copy()
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id)
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id,
key.UpdatedAt, key.AutoGroups)
}