mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-14 05:06:32 -04:00
Compare commits
7 Commits
wg_bind_pa
...
fix/wg_con
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c76e009723 | ||
|
|
6630ec013d | ||
|
|
49018a4a8e | ||
|
|
91636feb69 | ||
|
|
6fec0c682e | ||
|
|
c2e90a2a97 | ||
|
|
118880b6f7 |
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
@@ -32,6 +33,11 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
if hostName != "" {
|
||||
// nolint
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
|
||||
}
|
||||
|
||||
// workaround to run without service
|
||||
if logFile == "console" {
|
||||
err = handleRebrand(cmd)
|
||||
|
||||
@@ -45,6 +45,7 @@ var (
|
||||
managementURL string
|
||||
adminURL string
|
||||
setupKey string
|
||||
hostName string
|
||||
preSharedKey string
|
||||
natExternalIPs []string
|
||||
customDNSAddress string
|
||||
@@ -94,6 +95,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
rootCmd.AddCommand(downCmd)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -55,6 +56,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
if hostName != "" {
|
||||
// nolint
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
|
||||
}
|
||||
|
||||
if foregroundMode {
|
||||
return runInForegroundMode(ctx, cmd)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -326,12 +327,7 @@ func (conn *Conn) Open() error {
|
||||
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
|
||||
//
|
||||
// Please note, that this check happens when peers were already able to ping each other using ICE layer.
|
||||
func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
|
||||
|
||||
if !isRelayCandidate(pair.Local) && userspaceBind {
|
||||
log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed")
|
||||
return false
|
||||
}
|
||||
func shouldUseProxy(pair *ice.CandidatePair) bool {
|
||||
|
||||
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
|
||||
log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP")
|
||||
@@ -436,26 +432,59 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
||||
}
|
||||
|
||||
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
|
||||
useProxy := shouldUseProxy(pair, conn.config.UserspaceBind)
|
||||
localDirectMode := !useProxy
|
||||
remoteDirectMode := localDirectMode
|
||||
|
||||
if conn.meta.protoSupport.DirectCheck {
|
||||
go conn.sendLocalDirectMode(localDirectMode)
|
||||
// will block until message received or timeout
|
||||
remoteDirectMode = conn.receiveRemoteDirectMode()
|
||||
if !conn.config.UserspaceBind {
|
||||
useProxy := shouldUseProxy(pair)
|
||||
localDirectMode := !useProxy
|
||||
remoteDirectMode := localDirectMode
|
||||
|
||||
if conn.meta.protoSupport.DirectCheck {
|
||||
go conn.sendLocalDirectMode(localDirectMode)
|
||||
// will block until message received or timeout
|
||||
remoteDirectMode = conn.receiveRemoteDirectMode()
|
||||
}
|
||||
|
||||
if localDirectMode && remoteDirectMode {
|
||||
return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
|
||||
}
|
||||
|
||||
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
|
||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
||||
}
|
||||
|
||||
if conn.config.UserspaceBind && localDirectMode {
|
||||
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
||||
if isRelayCandidate(pair.Local) {
|
||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
||||
}
|
||||
|
||||
if localDirectMode && remoteDirectMode {
|
||||
return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
|
||||
// We decided to ignore the proxy decision when using Bind. Instead, we always punch remote WireGuard port to open a
|
||||
// hole in the firewall for that remote port to avoid cases when old clients assumes direct mode.
|
||||
mux := conn.config.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
|
||||
go func() {
|
||||
err := punchRemote(pair, remoteWgPort, mux)
|
||||
if err != nil {
|
||||
log.Warnf("failed to punch remote WireGuard port: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
||||
|
||||
}
|
||||
|
||||
func punchRemote(pair *ice.CandidatePair, remoteWgPort int, muxDefault *bind.UniversalUDPMuxDefault) error {
|
||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
|
||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = muxDefault.GetSharedConn().WriteTo([]byte{1}, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("puch msg has been sent: %d, %s, %v", i, addr.String(), muxDefault.GetSharedConn().LocalAddr())
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (conn *Conn) sendLocalDirectMode(localMode bool) {
|
||||
|
||||
@@ -78,6 +78,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||
defer d.mux.Unlock()
|
||||
d.offlinePeers = make([]State, len(replacement))
|
||||
copy(d.offlinePeers, replacement)
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
// AddPeer adds peer to Daemon status map
|
||||
@@ -308,7 +309,7 @@ func (d *Status) onConnectionChanged() {
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(len(d.peers))
|
||||
d.notifier.peerListChanged(len(d.peers) + len(d.offlinePeers))
|
||||
}
|
||||
|
||||
func (d *Status) notifyAddressChanged() {
|
||||
|
||||
@@ -43,6 +43,15 @@ func extractUserAgent(ctx context.Context) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractDeviceName extracts device name from context or returns the default system name
|
||||
func extractDeviceName(ctx context.Context, defaultName string) string {
|
||||
v, ok := ctx.Value(DeviceNameCtxKey).(string)
|
||||
if !ok {
|
||||
return defaultName
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// GetDesktopUIUserAgent returns the Desktop ui user agent
|
||||
func GetDesktopUIUserAgent() string {
|
||||
return "netbird-desktop-ui/" + version.NetbirdVersion()
|
||||
|
||||
@@ -24,21 +24,13 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
gio := &Info{Kernel: kernel, Core: osVersion(), Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname = extractDeviceName(ctx)
|
||||
gio.Hostname = extractDeviceName(ctx, "android")
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
return gio
|
||||
}
|
||||
|
||||
func extractDeviceName(ctx context.Context) string {
|
||||
v, ok := ctx.Value(DeviceNameCtxKey).(string)
|
||||
if !ok {
|
||||
return "android"
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func uname() []string {
|
||||
res := run("/system/bin/uname", "-a")
|
||||
return strings.Split(res, " ")
|
||||
|
||||
@@ -32,7 +32,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
swVersion = []byte(release)
|
||||
}
|
||||
gio := &Info{Kernel: sysName, OSVersion: strings.TrimSpace(string(swVersion)), Core: release, Platform: machine, OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname, _ = os.Hostname()
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
|
||||
@@ -24,7 +24,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
osStr = strings.Replace(osStr, "\r\n", "", -1)
|
||||
osInfo := strings.Split(osStr, " ")
|
||||
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname, _ = os.Hostname()
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
|
||||
@@ -50,7 +50,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
osName = osInfo[3]
|
||||
}
|
||||
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: osInfo[2], OS: osName, OSVersion: osVer, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname, _ = os.Hostname()
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
|
||||
@@ -24,3 +24,12 @@ func Test_UIVersion(t *testing.T) {
|
||||
got := GetInfo(ctx)
|
||||
assert.Equal(t, want, got.UIVersion)
|
||||
}
|
||||
|
||||
func Test_CustomHostname(t *testing.T) {
|
||||
// nolint
|
||||
ctx := context.WithValue(context.Background(), DeviceNameCtxKey, "custom-host")
|
||||
want := "custom-host"
|
||||
|
||||
got := GetInfo(ctx)
|
||||
assert.Equal(t, want, got.Hostname)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,8 @@ import (
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
ver := getOSVersion()
|
||||
gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname, _ = os.Hostname()
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
|
||||
2
go.mod
2
go.mod
@@ -19,7 +19,7 @@ require (
|
||||
github.com/vishvananda/netlink v1.1.0
|
||||
golang.org/x/crypto v0.7.0
|
||||
golang.org/x/sys v0.6.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
|
||||
golang.zx2c4.com/wireguard/windows v0.5.1
|
||||
google.golang.org/grpc v1.52.3
|
||||
|
||||
2
go.sum
2
go.sum
@@ -887,8 +887,6 @@ golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+D
|
||||
golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202 h1:KcD4X7IcoRdQpr9NSQpQpn5S4rUMIQaCdF90FOEj6KY=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202/go.mod h1:qc3aHNhM1Rc4hW2az896MjLVcxHvLbJ6LZc9MI7RTMY=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ=
|
||||
|
||||
@@ -1,416 +1,95 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/pion/stun"
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
var (
|
||||
_ wgConn.Bind = (*ICEBind)(nil)
|
||||
)
|
||||
|
||||
// ICEBind implements Bind for all platforms except Windows.
|
||||
// ICEBind is the userspace implementation of WireGuard's conn.Bind interface using ice.UDPMux of the pion/ice library
|
||||
type ICEBind struct {
|
||||
mu sync.Mutex // protects following fields
|
||||
ipv4 *net.UDPConn
|
||||
ipv6 *net.UDPConn
|
||||
blackhole4 bool
|
||||
blackhole6 bool
|
||||
ipv4PC *ipv4.PacketConn
|
||||
ipv6PC *ipv6.PacketConn
|
||||
batchSize int
|
||||
udpAddrPool sync.Pool
|
||||
ipv4MsgsPool sync.Pool
|
||||
ipv6MsgsPool sync.Pool
|
||||
// below fields, initialized on open
|
||||
ipv4 net.PacketConn
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
|
||||
// NetBird related variables
|
||||
// below are fields initialized on creation
|
||||
transportNet transport.Net
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
worker *worker
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewICEBind create a new instance of ICEBind with a given transportNet function.
|
||||
// The transportNet can be nil.
|
||||
func NewICEBind(transportNet transport.Net) *ICEBind {
|
||||
b := &ICEBind{
|
||||
batchSize: wgConn.DefaultBatchSize,
|
||||
|
||||
udpAddrPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &net.UDPAddr{
|
||||
IP: make([]byte, 16),
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
ipv4MsgsPool: sync.Pool{
|
||||
New: func() any {
|
||||
msgs := make([]ipv4.Message, wgConn.DefaultBatchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, srcControlSize)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
},
|
||||
|
||||
ipv6MsgsPool: sync.Pool{
|
||||
New: func() any {
|
||||
msgs := make([]ipv6.Message, wgConn.DefaultBatchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, srcControlSize)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
},
|
||||
return &ICEBind{
|
||||
transportNet: transportNet,
|
||||
}
|
||||
b.worker = newWorker(b.handlePkgs)
|
||||
return b
|
||||
}
|
||||
|
||||
type StdNetEndpoint struct {
|
||||
// AddrPort is the endpoint destination.
|
||||
netip.AddrPort
|
||||
// src is the current sticky source address and interface index, if supported.
|
||||
src struct {
|
||||
netip.Addr
|
||||
ifidx int32
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ wgConn.Bind = (*ICEBind)(nil)
|
||||
_ wgConn.Endpoint = &StdNetEndpoint{}
|
||||
)
|
||||
|
||||
func (*ICEBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
return asEndpoint(e), err
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) ClearSrc() {
|
||||
e.src.ifidx = 0
|
||||
e.src.Addr = netip.Addr{}
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||
return e.AddrPort.Addr()
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||
return e.src.Addr
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||
return e.src.ifidx
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||
b, _ := e.AddrPort.MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstToString() string {
|
||||
return e.AddrPort.String()
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcToString() string {
|
||||
return e.src.Addr.String()
|
||||
}
|
||||
|
||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Retrieve port.
|
||||
laddr := conn.LocalAddr()
|
||||
uaddr, err := net.ResolveUDPAddr(
|
||||
laddr.Network(),
|
||||
laddr.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return conn.(*net.UDPConn), uaddr.Port, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var tries int
|
||||
|
||||
if s.ipv4 != nil || s.ipv6 != nil {
|
||||
return nil, 0, wgConn.ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
||||
// If uport is 0, we can retry on failure.
|
||||
again:
|
||||
port := int(uport)
|
||||
var v4conn, v6conn *net.UDPConn
|
||||
|
||||
v4conn, port, err = listenNet("udp4", port)
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Listen on the same port as we're using for ipv4.
|
||||
v6conn, port, err = listenNet("udp6", port)
|
||||
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||
v4conn.Close()
|
||||
tries++
|
||||
goto again
|
||||
}
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
v4conn.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
var fns []wgConn.ReceiveFunc
|
||||
if v4conn != nil {
|
||||
fns = append(fns, s.receiveIPv4)
|
||||
s.ipv4 = v4conn
|
||||
}
|
||||
if v6conn != nil {
|
||||
fns = append(fns, s.receiveIPv6)
|
||||
s.ipv6 = v6conn
|
||||
}
|
||||
if len(fns) == 0 {
|
||||
return nil, 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
|
||||
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
|
||||
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: s.ipv4, Net: s.transportNet})
|
||||
return fns, uint16(port), nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) receiveIPv4(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer s.ipv4MsgsPool.Put(msgs)
|
||||
for i := range buffs {
|
||||
(*msgs)[i].Buffers[0] = buffs[i]
|
||||
}
|
||||
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.worker.doWork((*msgs)[:numMsgs], sizes, eps)
|
||||
return numMsgs, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) receiveIPv6(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
||||
defer s.ipv6MsgsPool.Put(msgs)
|
||||
for i := range buffs {
|
||||
(*msgs)[i].Buffers[0] = buffs[i]
|
||||
}
|
||||
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
sizes[i] = msg.N
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := asEndpoint(addrPort)
|
||||
getSrcFromControl(msg.OOB, ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) BatchSize() int {
|
||||
return s.batchSize
|
||||
}
|
||||
|
||||
func (s *ICEBind) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err1, err2 error
|
||||
if s.ipv4 != nil {
|
||||
err1 = s.ipv4.Close()
|
||||
s.ipv4 = nil
|
||||
}
|
||||
if s.ipv6 != nil {
|
||||
err2 = s.ipv6.Close()
|
||||
s.ipv6 = nil
|
||||
}
|
||||
s.blackhole4 = false
|
||||
s.blackhole6 = false
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (s *ICEBind) Send(buffs [][]byte, endpoint wgConn.Endpoint) error {
|
||||
s.mu.Lock()
|
||||
blackhole := s.blackhole4
|
||||
conn := s.ipv4
|
||||
is6 := false
|
||||
if endpoint.DstIP().Is6() {
|
||||
blackhole = s.blackhole6
|
||||
conn = s.ipv6
|
||||
is6 = true
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if blackhole {
|
||||
return nil
|
||||
}
|
||||
if conn == nil {
|
||||
return syscall.EAFNOSUPPORT
|
||||
}
|
||||
if is6 {
|
||||
return s.send6(s.ipv6PC, endpoint, buffs)
|
||||
} else {
|
||||
return s.send4(s.ipv4PC, endpoint, buffs)
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.udpMux == nil {
|
||||
func (b *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.udpMux == nil {
|
||||
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
||||
}
|
||||
|
||||
return s.udpMux, nil
|
||||
return b.udpMux, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) send4(conn *ipv4.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
|
||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||
as4 := ep.DstIP().As4()
|
||||
copy(ua.IP, as4[:])
|
||||
ua.IP = ua.IP[:4]
|
||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
for i, buff := range buffs {
|
||||
(*msgs)[i].Buffers[0] = buff
|
||||
(*msgs)[i].Addr = ua
|
||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
||||
// Open creates a WireGuard socket and an instance of UDPMux that is used to glue up ICE and WireGuard for hole punching
|
||||
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.ipv4 != nil {
|
||||
return nil, 0, conn.ErrBindAlreadyOpen
|
||||
}
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
start int
|
||||
|
||||
var err error
|
||||
b.ipv4, _, err = listenNet("udp4", int(uport))
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.ipv4, Net: b.transportNet})
|
||||
|
||||
portAddr, err := netip.ParseAddrPort(b.ipv4.LocalAddr().String())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
log.Infof("opened ICEBind on %s", b.ipv4.LocalAddr().String())
|
||||
|
||||
return []conn.ReceiveFunc{
|
||||
b.makeReceiveIPv4(b.ipv4),
|
||||
},
|
||||
portAddr.Port(), nil
|
||||
}
|
||||
|
||||
func listenNet(network string, port int) (net.PacketConn, int, error) {
|
||||
c, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
lAddr := c.LocalAddr()
|
||||
uAddr, err := net.ResolveUDPAddr(
|
||||
lAddr.Network(),
|
||||
lAddr.String(),
|
||||
)
|
||||
for {
|
||||
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
|
||||
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
||||
break
|
||||
}
|
||||
start += n
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
s.udpAddrPool.Put(ua)
|
||||
s.ipv4MsgsPool.Put(msgs)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ICEBind) send6(conn *ipv6.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
|
||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||
as16 := ep.DstIP().As16()
|
||||
copy(ua.IP, as16[:])
|
||||
ua.IP = ua.IP[:16]
|
||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
||||
for i, buff := range buffs {
|
||||
(*msgs)[i].Buffers[0] = buff
|
||||
(*msgs)[i].Addr = ua
|
||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
||||
}
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
start int
|
||||
)
|
||||
for {
|
||||
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
|
||||
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
||||
break
|
||||
}
|
||||
start += n
|
||||
}
|
||||
s.udpAddrPool.Put(ua)
|
||||
s.ipv6MsgsPool.Put(msgs)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||
for _, buffer := range buffers {
|
||||
if !stun.IsMessage(buffer) {
|
||||
continue
|
||||
}
|
||||
|
||||
msg, err := parseSTUNMessage(buffer[:n])
|
||||
if err != nil {
|
||||
buffer = []byte{}
|
||||
return true, err
|
||||
}
|
||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
||||
if muxErr != nil {
|
||||
log.Warnf("failed to handle packet")
|
||||
}
|
||||
buffer = []byte{}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) handlePkgs(msg *ipv4.Message) (int, *StdNetEndpoint) {
|
||||
// todo: handle err
|
||||
size := 0
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if !ok {
|
||||
size = msg.N
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := asEndpoint(addrPort)
|
||||
getSrcFromControl(msg.OOB, ep)
|
||||
return size, ep
|
||||
}
|
||||
|
||||
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
|
||||
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
|
||||
// but Endpoints are immutable, so we can re-use them.
|
||||
var endpointPool = sync.Pool{
|
||||
New: func() any {
|
||||
return make(map[netip.AddrPort]*StdNetEndpoint)
|
||||
},
|
||||
}
|
||||
|
||||
// asEndpoint returns an Endpoint containing ap.
|
||||
func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
|
||||
m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
|
||||
defer endpointPool.Put(m)
|
||||
e, ok := m[ap]
|
||||
if !ok {
|
||||
e = &StdNetEndpoint{AddrPort: ap}
|
||||
m[ap] = e
|
||||
}
|
||||
return e
|
||||
return c, uAddr.Port, nil
|
||||
}
|
||||
|
||||
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
||||
@@ -423,3 +102,107 @@ func parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
|
||||
return func(buff []byte) (int, conn.Endpoint, error) {
|
||||
n, endpoint, err := c.ReadFrom(buff)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
e, err := netip.ParseAddrPort(endpoint.String())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if !stun.IsMessage(buff) {
|
||||
// WireGuard traffic
|
||||
return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil
|
||||
}
|
||||
|
||||
msg, err := parseSTUNMessage(buff[:n])
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
|
||||
if err != nil {
|
||||
log.Warnf("failed to handle packet")
|
||||
}
|
||||
|
||||
// discard packets because they are STUN related
|
||||
return 0, nil, nil //todo proper return
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the WireGuard socket and UDPMux
|
||||
func (b *ICEBind) Close() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
var err1, err2 error
|
||||
if b.ipv4 != nil {
|
||||
c := b.ipv4
|
||||
b.ipv4 = nil
|
||||
err1 = c.Close()
|
||||
}
|
||||
|
||||
if b.udpMux != nil {
|
||||
m := b.udpMux
|
||||
b.udpMux = nil
|
||||
err2 = m.Close()
|
||||
}
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
|
||||
return err2
|
||||
}
|
||||
|
||||
// SetMark sets the mark for each packet sent through this Bind.
|
||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||
func (b *ICEBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send bytes to the remote endpoint (peer)
|
||||
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
|
||||
|
||||
nend, ok := endpoint.(conn.StdNetEndpoint)
|
||||
if !ok {
|
||||
return conn.ErrWrongEndpointType
|
||||
}
|
||||
addrPort := netip.AddrPort(nend)
|
||||
_, err := b.ipv4.WriteTo(buff, &net.UDPAddr{
|
||||
IP: addrPort.Addr().AsSlice(),
|
||||
Port: int(addrPort.Port()),
|
||||
Zone: addrPort.Addr().Zone(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string.
|
||||
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
return asEndpoint(e), err
|
||||
}
|
||||
|
||||
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
|
||||
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
|
||||
// but Endpoints are immutable, so we can re-use them.
|
||||
var endpointPool = sync.Pool{
|
||||
New: func() any {
|
||||
return make(map[netip.AddrPort]conn.Endpoint)
|
||||
},
|
||||
}
|
||||
|
||||
// asEndpoint returns an Endpoint containing ap.
|
||||
func asEndpoint(ap netip.AddrPort) conn.Endpoint {
|
||||
m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint)
|
||||
defer endpointPool.Put(m)
|
||||
e, ok := m[ap]
|
||||
if !ok {
|
||||
e = conn.Endpoint(conn.StdNetEndpoint(ap))
|
||||
m[ap] = e
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// controlFn is the callback function signature from net.ListenConfig.Control.
|
||||
// It is used to apply platform specific configuration to the socket prior to
|
||||
// bind.
|
||||
type controlFn func(network, address string, c syscall.RawConn) error
|
||||
|
||||
// controlFns is a list of functions that are called from the listen config
|
||||
// that can apply socket options.
|
||||
var controlFns = []controlFn{}
|
||||
|
||||
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
||||
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
||||
// information OOB configuration for sticky sockets.
|
||||
func listenConfig() *net.ListenConfig {
|
||||
return &net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
for _, fn := range controlFns {
|
||||
if err := fn(network, address, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func init() {
|
||||
controlFns = append(controlFns,
|
||||
|
||||
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
||||
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
var err error
|
||||
switch network {
|
||||
case "udp4":
|
||||
c.Control(func(fd uintptr) {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
||||
})
|
||||
case "udp6":
|
||||
c.Control(func(fd uintptr) {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||
})
|
||||
default:
|
||||
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
||||
}
|
||||
return err
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
//go:build !windows && !linux && !js
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func init() {
|
||||
controlFns = append(controlFns,
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
var err error
|
||||
if network == "udp6" {
|
||||
c.Control(func(fd uintptr) {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||
})
|
||||
}
|
||||
return err
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !linux && !openbsd && !freebsd
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bind
|
||||
|
||||
func (s *ICEBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
//go:build linux || openbsd || freebsd
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var fwmarkIoctl int
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
fwmarkIoctl = 36 /* unix.SO_MARK */
|
||||
case "freebsd":
|
||||
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
|
||||
case "openbsd":
|
||||
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ICEBind) SetMark(mark uint32) error {
|
||||
var operr error
|
||||
if fwmarkIoctl == 0 {
|
||||
return nil
|
||||
}
|
||||
if s.ipv4 != nil {
|
||||
fd, err := s.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = fd.Control(func(fd uintptr) {
|
||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||
})
|
||||
if err == nil {
|
||||
err = operr
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if s.ipv6 != nil {
|
||||
fd, err := s.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = fd.Control(func(fd uintptr) {
|
||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||
})
|
||||
if err == nil {
|
||||
err = operr
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bind
|
||||
|
||||
import wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
|
||||
// use alternatively named flags and need ports and require testing.
|
||||
|
||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func getSrcFromControl(control []byte, ep *wgConn.StdNetEndpoint) {
|
||||
}
|
||||
|
||||
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func setSrcControl(control *[]byte, ep *wgConn.StdNetEndpoint) {
|
||||
}
|
||||
|
||||
// srcControlSize returns the recommended buffer size for pooling sticky control
|
||||
// data.
|
||||
const srcControlSize = 0
|
||||
@@ -1,111 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||
ep.ClearSrc()
|
||||
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem []byte = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IP &&
|
||||
hdr.Type == unix.IP_PKTINFO {
|
||||
|
||||
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
|
||||
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
|
||||
ep.src.ifidx = info.Ifindex
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IPV6 &&
|
||||
hdr.Type == unix.IPV6_PKTINFO {
|
||||
|
||||
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
|
||||
ep.src.Addr = netip.AddrFrom16(info.Addr)
|
||||
ep.src.ifidx = int32(info.Ifindex)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
|
||||
// panics if buf is of insufficient size.
|
||||
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
|
||||
size := int(unsafe.Sizeof(t))
|
||||
if len(buf) < size {
|
||||
panic("pktInfoFromBuf: buffer too small")
|
||||
}
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
|
||||
return t
|
||||
}
|
||||
|
||||
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||
*control = (*control)[:cap(*control)]
|
||||
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
|
||||
*control = (*control)[:0]
|
||||
return
|
||||
}
|
||||
|
||||
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
|
||||
*control = (*control)[:0]
|
||||
return
|
||||
}
|
||||
|
||||
if len(*control) < srcControlSize {
|
||||
*control = (*control)[:0]
|
||||
return
|
||||
}
|
||||
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
|
||||
if ep.SrcIP().Is4() {
|
||||
hdr.Level = unix.IPPROTO_IP
|
||||
hdr.Type = unix.IP_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
||||
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
||||
info.Ifindex = ep.src.ifidx
|
||||
if ep.SrcIP().IsValid() {
|
||||
info.Spec_dst = ep.SrcIP().As4()
|
||||
}
|
||||
} else {
|
||||
hdr.Level = unix.IPPROTO_IPV6
|
||||
hdr.Type = unix.IPV6_PKTINFO
|
||||
hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo
|
||||
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
||||
info.Ifindex = uint32(ep.src.ifidx)
|
||||
if ep.SrcIP().IsValid() {
|
||||
info.Addr = ep.SrcIP().As16()
|
||||
}
|
||||
}
|
||||
|
||||
*control = (*control)[:hdr.Len]
|
||||
}
|
||||
|
||||
var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo)
|
||||
@@ -1,207 +0,0 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func Test_setSrcControl(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
ep := &StdNetEndpoint{
|
||||
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
|
||||
}
|
||||
ep.src.Addr = netip.MustParseAddr("127.0.0.1")
|
||||
ep.src.ifidx = 5
|
||||
|
||||
control := make([]byte, srcControlSize)
|
||||
|
||||
setSrcControl(&control, ep)
|
||||
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
if hdr.Level != unix.IPPROTO_IP {
|
||||
t.Errorf("unexpected level: %d", hdr.Level)
|
||||
}
|
||||
if hdr.Type != unix.IP_PKTINFO {
|
||||
t.Errorf("unexpected type: %d", hdr.Type)
|
||||
}
|
||||
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
|
||||
t.Errorf("unexpected length: %d", hdr.Len)
|
||||
}
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
|
||||
t.Errorf("unexpected address: %v", info.Spec_dst)
|
||||
}
|
||||
if info.Ifindex != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
ep := &StdNetEndpoint{
|
||||
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
|
||||
}
|
||||
ep.src.Addr = netip.MustParseAddr("::1")
|
||||
ep.src.ifidx = 5
|
||||
|
||||
control := make([]byte, srcControlSize)
|
||||
|
||||
setSrcControl(&control, ep)
|
||||
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
if hdr.Level != unix.IPPROTO_IPV6 {
|
||||
t.Errorf("unexpected level: %d", hdr.Level)
|
||||
}
|
||||
if hdr.Type != unix.IPV6_PKTINFO {
|
||||
t.Errorf("unexpected type: %d", hdr.Type)
|
||||
}
|
||||
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
|
||||
t.Errorf("unexpected length: %d", hdr.Len)
|
||||
}
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
if info.Addr != ep.SrcIP().As16() {
|
||||
t.Errorf("unexpected address: %v", info.Addr)
|
||||
}
|
||||
if info.Ifindex != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ClearOnNoSrc", func(t *testing.T) {
|
||||
control := make([]byte, srcControlSize)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = 1
|
||||
hdr.Type = 2
|
||||
hdr.Len = 3
|
||||
|
||||
setSrcControl(&control, &StdNetEndpoint{})
|
||||
|
||||
if len(control) != 0 {
|
||||
t.Errorf("unexpected control: %v", control)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getSrcFromControl(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
control := make([]byte, srcControlSize)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.IPPROTO_IP
|
||||
hdr.Type = unix.IP_PKTINFO
|
||||
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||
info.Ifindex = 5
|
||||
|
||||
ep := &StdNetEndpoint{}
|
||||
getSrcFromControl(control, ep)
|
||||
|
||||
if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
|
||||
t.Errorf("unexpected address: %v", ep.src.Addr)
|
||||
}
|
||||
if ep.src.ifidx != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
|
||||
}
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
control := make([]byte, srcControlSize)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.IPPROTO_IPV6
|
||||
hdr.Type = unix.IPV6_PKTINFO
|
||||
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
info.Ifindex = 5
|
||||
|
||||
ep := &StdNetEndpoint{}
|
||||
getSrcFromControl(control, ep)
|
||||
|
||||
if ep.SrcIP() != netip.MustParseAddr("::1") {
|
||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||
}
|
||||
if ep.src.ifidx != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
|
||||
}
|
||||
})
|
||||
t.Run("ClearOnEmpty", func(t *testing.T) {
|
||||
control := make([]byte, srcControlSize)
|
||||
ep := &StdNetEndpoint{}
|
||||
ep.src.Addr = netip.MustParseAddr("::1")
|
||||
ep.src.ifidx = 5
|
||||
|
||||
getSrcFromControl(control, ep)
|
||||
if ep.SrcIP().IsValid() {
|
||||
t.Errorf("unexpected address: %v", ep.src.Addr)
|
||||
}
|
||||
if ep.src.ifidx != 0 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_listenConfig(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
var i int
|
||||
sc.Control(func(fd uintptr) {
|
||||
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Error("IP_PKTINFO not set!")
|
||||
}
|
||||
} else {
|
||||
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||
}
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
var i int
|
||||
sc.Control(func(fd uintptr) {
|
||||
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Error("IPV6_PKTINFO not set!")
|
||||
}
|
||||
} else {
|
||||
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -75,6 +75,10 @@ type udpConn struct {
|
||||
logger logging.LeveledLogger
|
||||
}
|
||||
|
||||
func (m *UniversalUDPMuxDefault) GetSharedConn() net.PacketConn {
|
||||
return m.params.UDPConn
|
||||
}
|
||||
|
||||
// GetListenAddresses returns the listen addr of this UDP
|
||||
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
return []net.Addr{m.LocalAddr()}
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// todo: add close function
|
||||
type worker struct {
|
||||
jobOffer chan int
|
||||
numOfWorker int
|
||||
|
||||
jobFn func(msg *ipv4.Message) (int, *StdNetEndpoint)
|
||||
|
||||
messages []ipv4.Message
|
||||
sizes []int
|
||||
eps []wgConn.Endpoint
|
||||
}
|
||||
|
||||
func newWorker(jobFn func(msg *ipv4.Message) (int, *StdNetEndpoint)) *worker {
|
||||
w := &worker{
|
||||
jobOffer: make(chan int),
|
||||
numOfWorker: runtime.NumCPU(),
|
||||
jobFn: jobFn,
|
||||
}
|
||||
|
||||
w.populateWorkers()
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *worker) doWork(messages []ipv4.Message, sizes []int, eps []wgConn.Endpoint) {
|
||||
w.messages = messages
|
||||
w.sizes = sizes
|
||||
w.eps = eps
|
||||
|
||||
for i := 0; i < len(messages); i++ {
|
||||
w.jobOffer <- i
|
||||
}
|
||||
}
|
||||
|
||||
func (w *worker) populateWorkers() {
|
||||
for i := 0; i < w.numOfWorker; i++ {
|
||||
go w.loop()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *worker) loop() {
|
||||
for {
|
||||
select {
|
||||
case msgPos := <-w.jobOffer:
|
||||
w.sizes[msgPos], w.eps[msgPos] = w.jobFn(&w.messages[msgPos])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -52,7 +52,6 @@ func (t *tunDevice) Create() error {
|
||||
|
||||
log.Debugf("attaching to interface %v", name)
|
||||
t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||
t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
err = t.device.Up()
|
||||
if err != nil {
|
||||
|
||||
@@ -49,15 +49,17 @@ type AccountManager interface {
|
||||
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
|
||||
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||
CreateUser(accountID, userID string, key *UserInfo) (*UserInfo, error)
|
||||
CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
|
||||
DeleteUser(accountID, executingUserID string, targetUserID string) error
|
||||
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
||||
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
|
||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||
GetAccountByUserID(userID string) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
MarkPATUsed(tokenID string) error
|
||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
IsUserAdmin(userID string) (bool, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
GetPeerByKey(peerKey string) (*Peer, error)
|
||||
GetPeers(accountID, userID string) ([]*Peer, error)
|
||||
@@ -171,12 +173,13 @@ type Account struct {
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
IsServiceUser bool `json:"is_service_user"`
|
||||
}
|
||||
|
||||
// getRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
@@ -1228,9 +1231,11 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
}
|
||||
|
||||
err = am.redeemInvite(account, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
if !user.IsServiceUser {
|
||||
err = am.redeemInvite(account, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return account, user, nil
|
||||
|
||||
@@ -87,6 +87,10 @@ const (
|
||||
PersonalAccessTokenCreated
|
||||
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
|
||||
PersonalAccessTokenDeleted
|
||||
// ServiceUserCreated indicates that a user created a service user
|
||||
ServiceUserCreated
|
||||
// ServiceUserDeleted indicates that a user deleted a service user
|
||||
ServiceUserDeleted
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -176,6 +180,10 @@ const (
|
||||
PersonalAccessTokenCreatedMessage string = "Personal access token created"
|
||||
// PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity
|
||||
PersonalAccessTokenDeletedMessage string = "Personal access token deleted"
|
||||
// ServiceUserCreatedMessage is a human-readable text message of the ServiceUserCreated activity
|
||||
ServiceUserCreatedMessage string = "Service user created"
|
||||
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
|
||||
ServiceUserDeletedMessage string = "Service user deleted"
|
||||
)
|
||||
|
||||
// Activity that triggered an Event
|
||||
@@ -270,6 +278,10 @@ func (a Activity) Message() string {
|
||||
return PersonalAccessTokenCreatedMessage
|
||||
case PersonalAccessTokenDeleted:
|
||||
return PersonalAccessTokenDeletedMessage
|
||||
case ServiceUserCreated:
|
||||
return ServiceUserCreatedMessage
|
||||
case ServiceUserDeleted:
|
||||
return ServiceUserDeletedMessage
|
||||
default:
|
||||
return "UNKNOWN_ACTIVITY"
|
||||
}
|
||||
@@ -364,6 +376,10 @@ func (a Activity) StringCode() string {
|
||||
return "personal.access.token.create"
|
||||
case PersonalAccessTokenDeleted:
|
||||
return "personal.access.token.delete"
|
||||
case ServiceUserCreated:
|
||||
return "service.user.create"
|
||||
case ServiceUserDeleted:
|
||||
return "service.user.delete"
|
||||
default:
|
||||
return "UNKNOWN_ACTIVITY"
|
||||
}
|
||||
|
||||
@@ -77,6 +77,10 @@ components:
|
||||
description: Is true if authenticated user is the same as this user
|
||||
type: boolean
|
||||
readOnly: true
|
||||
is_service_user:
|
||||
description: Is true if this user is a service user
|
||||
type: boolean
|
||||
readOnly: true
|
||||
required:
|
||||
- id
|
||||
- email
|
||||
@@ -115,10 +119,13 @@ components:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
is_service_user:
|
||||
description: Is true if this user is a service user
|
||||
type: boolean
|
||||
required:
|
||||
- role
|
||||
- auto_groups
|
||||
- email
|
||||
- is_service_user
|
||||
PeerMinimum:
|
||||
type: object
|
||||
properties:
|
||||
@@ -825,6 +832,12 @@ paths:
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
parameters:
|
||||
- in: query
|
||||
name: service_user
|
||||
schema:
|
||||
type: boolean
|
||||
description: Filters users and returns either normal users or service users
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON array of Users
|
||||
@@ -903,6 +916,30 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
delete:
|
||||
summary: Delete a User
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The User ID
|
||||
responses:
|
||||
'200':
|
||||
description: Delete status code
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/{userId}/tokens:
|
||||
get:
|
||||
summary: Returns a list of all tokens for a user
|
||||
|
||||
@@ -677,6 +677,9 @@ type User struct {
|
||||
// IsCurrent Is true if authenticated user is the same as this user
|
||||
IsCurrent *bool `json:"is_current,omitempty"`
|
||||
|
||||
// IsServiceUser Is true if this user is a service user
|
||||
IsServiceUser *bool `json:"is_service_user,omitempty"`
|
||||
|
||||
// Name User's name from idp provider
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -696,7 +699,10 @@ type UserCreateRequest struct {
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Email User's Email to send invite to
|
||||
Email string `json:"email"`
|
||||
Email *string `json:"email,omitempty"`
|
||||
|
||||
// IsServiceUser Is true if this user is a service user
|
||||
IsServiceUser bool `json:"is_service_user"`
|
||||
|
||||
// Name User's full name
|
||||
Name *string `json:"name,omitempty"`
|
||||
@@ -787,6 +793,12 @@ type PutApiRulesIdJSONBody struct {
|
||||
Sources *[]string `json:"sources,omitempty"`
|
||||
}
|
||||
|
||||
// GetApiUsersParams defines parameters for GetApiUsers.
|
||||
type GetApiUsersParams struct {
|
||||
// ServiceUser Filters users and returns either normal users or service users
|
||||
ServiceUser *bool `form:"service_user,omitempty" json:"service_user,omitempty"`
|
||||
}
|
||||
|
||||
// PutApiAccountsIdJSONRequestBody defines body for PutApiAccountsId for application/json ContentType.
|
||||
type PutApiAccountsIdJSONRequestBody PutApiAccountsIdJSONBody
|
||||
|
||||
|
||||
@@ -111,6 +111,7 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
|
||||
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users/{id}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
type IsUserAdminFunc func(userID string) (bool, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
@@ -37,7 +37,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
ok, err := a.isUserAdmin(claims)
|
||||
ok, err := a.isUserAdmin(claims.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
|
||||
@@ -3,8 +3,10 @@ package http
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
@@ -77,6 +79,36 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// DeleteUser is a DELETE request to delete a user (only works for service users right now)
|
||||
func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["id"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -103,11 +135,17 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
email := ""
|
||||
if req.Email != nil {
|
||||
email = *req.Email
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{
|
||||
Email: req.Email,
|
||||
Name: *req.Name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
Email: email,
|
||||
Name: *req.Name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
IsServiceUser: req.IsServiceUser,
|
||||
})
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@@ -137,9 +175,27 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
serviceUser := r.URL.Query().Get("service_user")
|
||||
|
||||
log.Debugf("UserCount: %v", len(data))
|
||||
|
||||
users := make([]*api.User, 0)
|
||||
for _, r := range data {
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
continue
|
||||
}
|
||||
includeServiceUser, err := strconv.ParseBool(serviceUser)
|
||||
log.Debugf("Should include service user: %v", includeServiceUser)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
return
|
||||
}
|
||||
log.Debugf("User %v is service user: %v", r.Name, r.IsServiceUser)
|
||||
if includeServiceUser == r.IsServiceUser {
|
||||
log.Debugf("Found service user: %v", r.Name)
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, users)
|
||||
@@ -163,12 +219,13 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
|
||||
isCurrent := user.ID == currenUserID
|
||||
return &api.User{
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
IsServiceUser: &user.IsServiceUser,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,52 +1,91 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
func initUsers(user ...*server.User) *UsersHandler {
|
||||
const (
|
||||
serviceUserID = "serviceUserID"
|
||||
regularUserID = "regularUserID"
|
||||
)
|
||||
|
||||
var usersTestAccount = &server.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: domain,
|
||||
Users: map[string]*server.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: false,
|
||||
},
|
||||
regularUserID: {
|
||||
Id: regularUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: false,
|
||||
},
|
||||
serviceUserID: {
|
||||
Id: serviceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func initUsersTestData() *UsersHandler {
|
||||
return &UsersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
users := make(map[string]*server.User, 0)
|
||||
for _, u := range user {
|
||||
users[u.Id] = u
|
||||
}
|
||||
return &server.Account{
|
||||
Id: "12345",
|
||||
Domain: "netbird.io",
|
||||
Users: users,
|
||||
}, users[claims.UserId], nil
|
||||
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||
users := make([]*server.UserInfo, 0)
|
||||
for _, v := range user {
|
||||
for _, v := range usersTestAccount.Users {
|
||||
users = append(users, &server.UserInfo{
|
||||
ID: v.Id,
|
||||
Role: string(v.Role),
|
||||
Name: "",
|
||||
Email: "",
|
||||
ID: v.Id,
|
||||
Role: string(v.Role),
|
||||
Name: "",
|
||||
Email: "",
|
||||
IsServiceUser: v.IsServiceUser,
|
||||
})
|
||||
}
|
||||
return users, nil
|
||||
},
|
||||
CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
|
||||
if userID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
return key, nil
|
||||
},
|
||||
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
|
||||
if targetUserID == notFoundUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||
}
|
||||
if !usersTestAccount.Users[targetUserID].IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "user with ID %s is not a service user and can not be deleted", targetUserID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "1",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
UserId: existingUserID,
|
||||
Domain: domain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
@@ -54,8 +93,84 @@ func initUsers(user ...*server.User) *UsersHandler {
|
||||
}
|
||||
|
||||
func TestGetUsers(t *testing.T) {
|
||||
users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}}
|
||||
userHandler := initUsers(users...)
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestType string
|
||||
requestPath string
|
||||
expectedUserIDs []string
|
||||
}{
|
||||
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}},
|
||||
{name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}},
|
||||
{name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
|
||||
userHandler.GetAllUsers(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
respBody := []*server.UserInfo{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, len(respBody), len(tc.expectedUserIDs))
|
||||
for _, v := range respBody {
|
||||
assert.Contains(t, tc.expectedUserIDs, v.ID)
|
||||
assert.Equal(t, v.ID, usersTestAccount.Users[v.ID].Id)
|
||||
assert.Equal(t, v.Role, string(usersTestAccount.Users[v.ID].Role))
|
||||
assert.Equal(t, v.IsServiceUser, usersTestAccount.Users[v.ID].IsServiceUser)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
name := "name"
|
||||
email := "email"
|
||||
serviceUserToAdd := api.UserCreateRequest{
|
||||
AutoGroups: []string{},
|
||||
Email: nil,
|
||||
IsServiceUser: true,
|
||||
Name: &name,
|
||||
Role: "admin",
|
||||
}
|
||||
serviceUserString, err := json.Marshal(serviceUserToAdd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
regularUserToAdd := api.UserCreateRequest{
|
||||
AutoGroups: []string{},
|
||||
Email: &email,
|
||||
IsServiceUser: true,
|
||||
Name: &name,
|
||||
Role: "admin",
|
||||
}
|
||||
regularUserString, err := json.Marshal(regularUserToAdd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
@@ -65,40 +180,79 @@ func TestGetUsers(t *testing.T) {
|
||||
requestBody io.Reader
|
||||
expectedResult []*server.User
|
||||
}{
|
||||
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users},
|
||||
{name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)},
|
||||
// right now creation is blocked in AC middleware, will be refactored in the future
|
||||
{name: "CreateRegularUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(regularUserString)},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.GetAllUsers(rr, req)
|
||||
userHandler.CreateUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, http.StatusOK)
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
respBody := []*server.UserInfo{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
if tc.expectedResult != nil {
|
||||
for i, resp := range respBody {
|
||||
assert.Equal(t, resp.ID, tc.expectedResult[i].Id)
|
||||
assert.Equal(t, string(resp.Role), string(tc.expectedResult[i].Role))
|
||||
}
|
||||
status, tc.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestType string
|
||||
requestPath string
|
||||
requestVars map[string]string
|
||||
requestBody io.Reader
|
||||
}{
|
||||
{
|
||||
name: "Delete Regular User",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
requestVars: map[string]string{"id": regularUserID},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Delete Service User",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + serviceUserID,
|
||||
requestVars: map[string]string{"id": serviceUserID},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete Not Existing User",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + notFoundUserID,
|
||||
requestVars: map[string]string{"id": notFoundUserID},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.DeleteUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, tc.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,12 +15,12 @@ import (
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
||||
GetAccountByUserIDFunc func(userID string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
IsUserAdminFunc func(userID string) (bool, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
||||
@@ -61,6 +61,7 @@ type MockAccountManager struct {
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
@@ -112,12 +113,12 @@ func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||
)
|
||||
}
|
||||
|
||||
// GetAccountByUser mock implementation of GetAccountByUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account, error) {
|
||||
if am.GetAccountByUserFunc != nil {
|
||||
return am.GetAccountByUserFunc(userId)
|
||||
// GetAccountByUserID mock implementation of GetAccountByUserID from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByUserID(userID string) (*server.Account, error) {
|
||||
if am.GetAccountByUserIDFunc != nil {
|
||||
return am.GetAccountByUserIDFunc(userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserID is not implemented")
|
||||
}
|
||||
|
||||
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
|
||||
@@ -394,9 +395,9 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
|
||||
}
|
||||
|
||||
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
|
||||
func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
func (am *MockAccountManager) IsUserAdmin(userID string) (bool, error) {
|
||||
if am.IsUserAdminFunc != nil {
|
||||
return am.IsUserAdminFunc(claims)
|
||||
return am.IsUserAdminFunc(userID)
|
||||
}
|
||||
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
|
||||
}
|
||||
@@ -500,6 +501,14 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser mocks DeleteUser of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
|
||||
if am.DeleteUserFunc != nil {
|
||||
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||
}
|
||||
|
||||
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
if am.GetNameServerGroupFunc != nil {
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
@@ -42,8 +42,11 @@ type UserRole string
|
||||
|
||||
// User represents a user of the system
|
||||
type User struct {
|
||||
Id string
|
||||
Role UserRole
|
||||
Id string
|
||||
Role UserRole
|
||||
IsServiceUser bool
|
||||
// ServiceUserName is only set if IsServiceUser is true
|
||||
ServiceUserName string
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string
|
||||
PATs map[string]*PersonalAccessToken
|
||||
@@ -63,12 +66,13 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
|
||||
if userData == nil {
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: "",
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: u.ServiceUserName,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
@@ -81,12 +85,13 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -101,34 +106,88 @@ func (u *User) Copy() *User {
|
||||
pats[k] = patCopy
|
||||
}
|
||||
return &User{
|
||||
Id: u.Id,
|
||||
Role: u.Role,
|
||||
AutoGroups: autoGroups,
|
||||
PATs: pats,
|
||||
Id: u.Id,
|
||||
Role: u.Role,
|
||||
AutoGroups: autoGroups,
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
ServiceUserName: u.ServiceUserName,
|
||||
PATs: pats,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUser creates a new user
|
||||
func NewUser(id string, role UserRole) *User {
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, serviceUserName string, autoGroups []string) *User {
|
||||
return &User{
|
||||
Id: id,
|
||||
Role: role,
|
||||
AutoGroups: []string{},
|
||||
Id: id,
|
||||
Role: role,
|
||||
IsServiceUser: isServiceUser,
|
||||
ServiceUserName: serviceUserName,
|
||||
AutoGroups: autoGroups,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRegularUser creates a new user with role UserRoleAdmin
|
||||
// NewRegularUser creates a new user with role UserRoleUser
|
||||
func NewRegularUser(id string) *User {
|
||||
return NewUser(id, UserRoleUser)
|
||||
return NewUser(id, UserRoleUser, false, "", []string{})
|
||||
}
|
||||
|
||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||
func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin)
|
||||
return NewUser(id, UserRoleAdmin, false, "", []string{})
|
||||
}
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
if executingUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only admins can create service users")
|
||||
}
|
||||
|
||||
newUserID := uuid.New().String()
|
||||
newUser := NewUser(newUserID, role, true, serviceUserName, autoGroups)
|
||||
log.Debugf("New User: %v", newUser)
|
||||
account.Users[newUserID] = newUser
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": newUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
||||
|
||||
return &UserInfo{
|
||||
ID: newUser.Id,
|
||||
Email: "",
|
||||
Name: newUser.ServiceUserName,
|
||||
Role: string(newUser.Role),
|
||||
AutoGroups: newUser.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user under the given account. Effectively this is a user invite.
|
||||
func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *UserInfo) (*UserInfo, error) {
|
||||
if user.IsServiceUser {
|
||||
return am.createServiceUser(accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.AutoGroups)
|
||||
}
|
||||
return am.inviteNewUser(accountID, userID, user)
|
||||
}
|
||||
|
||||
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
||||
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -193,8 +252,48 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *Us
|
||||
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user from the given account.
|
||||
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin {
|
||||
return status.Errorf(status.PermissionDenied, "only admins can delete service users")
|
||||
}
|
||||
|
||||
if !targetUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "regular users can not be deleted")
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -206,21 +305,26 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
||||
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
|
||||
}
|
||||
|
||||
if executingUserID != targetUserId {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser := account.Users[targetUserId]
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "targetUser not found")
|
||||
}
|
||||
|
||||
pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id)
|
||||
executingUser := account.Users[executingUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||
}
|
||||
|
||||
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
|
||||
}
|
||||
@@ -232,8 +336,8 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
||||
return nil, status.Errorf(status.Internal, "failed to save account: %v", err)
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name}
|
||||
am.storeEvent(executingUserID, targetUserId, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||
|
||||
return pat, nil
|
||||
}
|
||||
@@ -243,21 +347,26 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if executingUserID != targetUserID {
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
}
|
||||
|
||||
user := account.Users[targetUserID]
|
||||
if user == nil {
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
executingUser := account.Users[executingUserID]
|
||||
if targetUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||
}
|
||||
|
||||
pat := targetUser.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
@@ -271,10 +380,10 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err)
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name}
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||
|
||||
delete(user.PATs, tokenID)
|
||||
delete(targetUser.PATs, tokenID)
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
@@ -288,21 +397,26 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if executingUserID != targetUserID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
}
|
||||
|
||||
user := account.Users[targetUserID]
|
||||
if user == nil {
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
executingUser := account.Users[executingUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
||||
}
|
||||
|
||||
pat := targetUser.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
@@ -315,22 +429,27 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if executingUserID != targetUserID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
}
|
||||
|
||||
user := account.Users[targetUserID]
|
||||
if user == nil {
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
var pats []*PersonalAccessToken
|
||||
for _, pat := range user.PATs {
|
||||
for _, pat := range targetUser.PATs {
|
||||
pats = append(pats, pat)
|
||||
}
|
||||
|
||||
@@ -386,7 +505,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID})
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
@@ -397,14 +516,14 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID})
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
||||
userData, err := am.lookupUserInCache(newUser.Id, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -454,14 +573,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// IsUserAdmin flag for current user authenticated by JWT token
|
||||
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
account, _, err := am.GetAccountFromToken(claims)
|
||||
// GetAccountByUserID returns an existing account for a given user id
|
||||
func (am *DefaultAccountManager) GetAccountByUserID(userID string) (*Account, error) {
|
||||
return am.Store.GetAccountByUser(userID)
|
||||
}
|
||||
|
||||
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
|
||||
func (am *DefaultAccountManager) IsUserAdmin(userID string) (bool, error) {
|
||||
account, err := am.GetAccountByUserID(userID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get account: %v", err)
|
||||
}
|
||||
|
||||
user, ok := account.Users[claims.UserId]
|
||||
user, ok := account.Users[userID]
|
||||
if !ok {
|
||||
return false, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
@@ -486,7 +610,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
if !isNil(am.idpManager) {
|
||||
users := make(map[string]struct{}, len(account.Users))
|
||||
for _, user := range account.Users {
|
||||
users[user.Id] = struct{}{}
|
||||
if !user.IsServiceUser {
|
||||
users[user.Id] = struct{}{}
|
||||
}
|
||||
}
|
||||
queriedUsers, err = am.lookupCache(users, accountID)
|
||||
if err != nil {
|
||||
@@ -512,20 +638,44 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
return userInfos, nil
|
||||
}
|
||||
|
||||
for _, queriedUser := range queriedUsers {
|
||||
if !user.IsAdmin() && user.Id != queriedUser.ID {
|
||||
for _, localUser := range account.Users {
|
||||
if !user.IsAdmin() && user.Id != localUser.Id {
|
||||
// if user is not an admin then show only current user and do not show other users
|
||||
continue
|
||||
}
|
||||
if localUser, contains := account.Users[queriedUser.ID]; contains {
|
||||
|
||||
info, err := localUser.toUserInfo(queriedUser)
|
||||
var info *UserInfo
|
||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||
info, err = localUser.toUserInfo(queriedUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userInfos = append(userInfos, info)
|
||||
} else {
|
||||
name := ""
|
||||
if localUser.IsServiceUser {
|
||||
name = localUser.ServiceUserName
|
||||
}
|
||||
info = &UserInfo{
|
||||
ID: localUser.Id,
|
||||
Email: "",
|
||||
Name: name,
|
||||
Role: string(localUser.Role),
|
||||
AutoGroups: localUser.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: localUser.IsServiceUser,
|
||||
}
|
||||
}
|
||||
userInfos = append(userInfos, info)
|
||||
}
|
||||
|
||||
return userInfos, nil
|
||||
}
|
||||
|
||||
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
||||
for _, user := range userData {
|
||||
if user.ID == userID {
|
||||
return user, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -1,25 +1,32 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
const (
|
||||
mockAccountID = "accountID"
|
||||
mockUserID = "userID"
|
||||
mockTargetUserId = "targetUserID"
|
||||
mockTokenID1 = "tokenID1"
|
||||
mockToken1 = "SoMeHaShEdToKeN1"
|
||||
mockTokenID2 = "tokenID2"
|
||||
mockToken2 = "SoMeHaShEdToKeN2"
|
||||
mockTokenName = "tokenName"
|
||||
mockEmptyTokenName = ""
|
||||
mockExpiresIn = 7
|
||||
mockWrongExpiresIn = 4506
|
||||
mockAccountID = "accountID"
|
||||
mockUserID = "userID"
|
||||
mockServiceUserID = "serviceUserID"
|
||||
mockRole = "user"
|
||||
mockServiceUserName = "serviceUserName"
|
||||
mockTargetUserId = "targetUserID"
|
||||
mockTokenID1 = "tokenID1"
|
||||
mockToken1 = "SoMeHaShEdToKeN1"
|
||||
mockTokenID2 = "tokenID2"
|
||||
mockToken2 = "SoMeHaShEdToKeN2"
|
||||
mockTokenName = "tokenName"
|
||||
mockEmptyTokenName = ""
|
||||
mockExpiresIn = 7
|
||||
mockWrongExpiresIn = 4506
|
||||
)
|
||||
|
||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
@@ -41,6 +48,8 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.CreatedBy, mockUserID)
|
||||
|
||||
fileStore := am.Store.(*FileStore)
|
||||
tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken]
|
||||
|
||||
@@ -60,7 +69,10 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: false,
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
@@ -75,6 +87,31 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
assert.Errorf(t, err, "Creating PAT for different user should thorw error")
|
||||
}
|
||||
|
||||
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: true,
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
pat, err := am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.CreatedBy, mockUserID)
|
||||
}
|
||||
|
||||
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
@@ -207,3 +244,300 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
|
||||
assert.Equal(t, 2, len(pats))
|
||||
}
|
||||
|
||||
func TestUser_Copy(t *testing.T) {
|
||||
// this is an imaginary case which will never be in DB this way
|
||||
user := User{
|
||||
Id: "userId",
|
||||
Role: "role",
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "servicename",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"pat1": {
|
||||
ID: "pat1",
|
||||
Name: "First PAT",
|
||||
HashedToken: "SoMeHaShEdToKeN",
|
||||
ExpirationDate: time.Now().AddDate(0, 0, 7),
|
||||
CreatedBy: "userId",
|
||||
CreatedAt: time.Now(),
|
||||
LastUsed: time.Now(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := validateStruct(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Test needs update: dummy struct has not all fields set : %s", err)
|
||||
}
|
||||
|
||||
copiedUser := user.Copy()
|
||||
|
||||
assert.True(t, cmp.Equal(user, *copiedUser))
|
||||
}
|
||||
|
||||
// based on https://medium.com/@anajankow/fast-check-if-all-struct-fields-are-set-in-golang-bba1917213d2
|
||||
func validateStruct(s interface{}) (err error) {
|
||||
|
||||
structType := reflect.TypeOf(s)
|
||||
structVal := reflect.ValueOf(s)
|
||||
fieldNum := structVal.NumField()
|
||||
|
||||
for i := 0; i < fieldNum; i++ {
|
||||
field := structVal.Field(i)
|
||||
fieldName := structType.Field(i).Name
|
||||
|
||||
isSet := field.IsValid() && !field.IsZero()
|
||||
|
||||
if !isSet {
|
||||
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func TestUser_CreateServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
user, err := am.createServiceUser(mockAccountID, mockUserID, mockRole, mockServiceUserName, []string{"group1", "group2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating service user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
|
||||
assert.NotNil(t, store.Accounts[mockAccountID].Users[user.ID])
|
||||
assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
|
||||
assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
|
||||
assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
|
||||
assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
|
||||
assert.Equal(t, map[string]*PersonalAccessToken{}, store.Accounts[mockAccountID].Users[user.ID].PATs)
|
||||
|
||||
assert.Zero(t, user.Email)
|
||||
assert.True(t, user.IsServiceUser)
|
||||
assert.Equal(t, "active", user.Status)
|
||||
}
|
||||
|
||||
func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
user, err := am.CreateUser(mockAccountID, mockUserID, &UserInfo{
|
||||
Name: mockServiceUserName,
|
||||
Role: mockRole,
|
||||
IsServiceUser: true,
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating user: %s", err)
|
||||
}
|
||||
|
||||
assert.True(t, user.IsServiceUser)
|
||||
assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
|
||||
assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
|
||||
assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
|
||||
assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
|
||||
assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
|
||||
|
||||
assert.Equal(t, mockServiceUserName, user.Name)
|
||||
assert.Equal(t, mockRole, user.Role)
|
||||
assert.Equal(t, []string{"group1", "group2"}, user.AutoGroups)
|
||||
assert.Equal(t, "active", user.Status)
|
||||
}
|
||||
|
||||
func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
_, err = am.CreateUser(mockAccountID, mockUserID, &UserInfo{
|
||||
Name: mockServiceUserName,
|
||||
Role: mockRole,
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
})
|
||||
|
||||
assert.Errorf(t, err, "Not configured IDP will throw error but right path used")
|
||||
}
|
||||
|
||||
func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: mockServiceUserName,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
err = am.DeleteUser(mockAccountID, mockUserID, mockServiceUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when deleting user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(store.Accounts[mockAccountID].Users))
|
||||
assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
|
||||
}
|
||||
|
||||
func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
|
||||
|
||||
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
||||
}
|
||||
|
||||
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
ok, err := am.IsUserAdmin(mockUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
ok, err := am.IsUserAdmin(mockUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
users, err := am.GetUsersFromAccount(mockAccountID, mockUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting users from account: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(users))
|
||||
}
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
users, err := am.GetUsersFromAccount(mockAccountID, mockServiceUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting users from account: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(users))
|
||||
assert.Equal(t, mockServiceUserID, users[0].ID)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user