Compare commits

..

2 Commits

Author SHA1 Message Date
Misha Bragin
fa0399d975 Add more interfaces to ignore (#351) 2022-06-04 20:15:41 +02:00
Misha Bragin
e6e9f0322f Handle peer interface config change (#348)
Before this change, NetBird Agent wasn't handling
peer interface configuration changes dynamically.
Also, remote peer configuration changes have
not been applied (e.g. AllowedIPs changed).
Not a very common cause, but still it should be handled.
Now, Agent reacts to PeerConfig changes sent from the
management service and restarts remote connections
if AllowedIps have been changed.
2022-06-04 19:41:01 +02:00
19 changed files with 354 additions and 388 deletions

View File

@@ -1,4 +1,4 @@
FROM gcr.io/distroless/base:debug
ENV WT_LOG_FILE=console
ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird
COPY netbird /go/bin/netbird

View File

@@ -58,7 +58,8 @@ func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (
config.PreSharedKey = preSharedKey
}
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "tun0"}
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"}
err := util.WriteJson(configPath, config)
if err != nil {

View File

@@ -78,7 +78,7 @@ type Engine struct {
ctx context.Context
wgInterface iface.WGIface
wgInterface *iface.WGIface
udpMux ice.UDPMux
udpMuxSrflx ice.UniversalUDPMux
@@ -177,7 +177,7 @@ func (e *Engine) Start() error {
myPrivateKey := e.config.WgPrivateKey
var err error
e.wgInterface, err = iface.NewWGIface(wgIfaceName, wgAddr, iface.DefaultMTU)
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error())
return err
@@ -216,7 +216,39 @@ func (e *Engine) Start() error {
return nil
}
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// first, check if peers have been modified
var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate {
if peerConn, ok := e.peerConns[p.GetWgPubKey()]; ok {
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p)
}
}
}
// second, close all modified connections and remove them from the state map
for _, p := range modified {
err := e.removePeer(p.GetWgPubKey())
if err != nil {
return err
}
}
// third, add the peer connections again
for _, p := range modified {
err := e.addNewPeer(p)
if err != nil {
return err
}
}
return nil
}
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
currentPeers := make([]string, 0, len(e.peerConns))
for p := range e.peerConns {
@@ -366,6 +398,12 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
if update.GetNetworkMap() != nil {
if update.GetNetworkMap().GetPeerConfig() != nil {
err := e.updateConfig(update.GetNetworkMap().GetPeerConfig())
if err != nil {
return err
}
}
// only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap())
if err != nil {
@@ -376,6 +414,20 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface.Address.String() != conf.Address {
oldAddr := e.wgInterface.Address.String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
err := e.wgInterface.UpdateAddr(conf.Address)
if err != nil {
return err
}
log.Infof("updated peer address from %s to %s", oldAddr, conf.Address)
}
return nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
@@ -454,6 +506,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return err
}
err = e.modifyPeers(networkMap.GetRemotePeers())
if err != nil {
return err
}
err = e.addNewPeers(networkMap.GetRemotePeers())
if err != nil {
return err
@@ -464,21 +521,29 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
// addNewPeers finds and adds peers that were not know before but arrived from the Management service with the update
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate {
peerKey := p.GetWgPubKey()
peerIPs := p.GetAllowedIps()
if _, ok := e.peerConns[peerKey]; !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
if err != nil {
return err
}
e.peerConns[peerKey] = conn
go e.connWorker(conn, peerKey)
err := e.addNewPeer(p)
if err != nil {
return err
}
}
return nil
}
// addNewPeer add peer if connection doesn't exist
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerKey := peerConfig.GetWgPubKey()
peerIPs := peerConfig.GetAllowedIps()
if _, ok := e.peerConns[peerKey]; !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
if err != nil {
return err
}
e.peerConns[peerKey] = conn
go e.connWorker(conn, peerKey)
}
return nil
}
@@ -505,6 +570,12 @@ func (e Engine) connWorker(conn *peer.Conn, peerKey string) {
err := conn.Open()
if err != nil {
log.Debugf("connection to peer %s failed: %v", peerKey, err)
switch err.(type) {
case *peer.ConnectionClosedError:
// conn has been forced to close, so we exit the loop
return
default:
}
}
}
}

View File

@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"time"
@@ -62,7 +63,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
networkMap *mgmtProto.NetworkMap
expectedLen int
expectedPeers []string
expectedPeers []*mgmtProto.RemotePeerConfig
expectedSerial uint64
}
@@ -81,6 +82,11 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
AllowedIps: []string{"100.64.0.12/24"},
}
modifiedPeer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.20/24"},
}
case1 := testCase{
name: "input with a new peer to add",
networkMap: &mgmtProto.NetworkMap{
@@ -92,7 +98,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemotePeersIsEmpty: false,
},
expectedLen: 1,
expectedPeers: []string{peer1.GetWgPubKey()},
expectedPeers: []*mgmtProto.RemotePeerConfig{peer1},
expectedSerial: 1,
}
@@ -108,7 +114,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemotePeersIsEmpty: false,
},
expectedLen: 2,
expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()},
expectedPeers: []*mgmtProto.RemotePeerConfig{peer1, peer2},
expectedSerial: 2,
}
@@ -123,7 +129,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemotePeersIsEmpty: false,
},
expectedLen: 2,
expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()},
expectedPeers: []*mgmtProto.RemotePeerConfig{peer1, peer2},
expectedSerial: 2,
}
@@ -138,11 +144,26 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemotePeersIsEmpty: false,
},
expectedLen: 2,
expectedPeers: []string{peer2.GetWgPubKey(), peer3.GetWgPubKey()},
expectedPeers: []*mgmtProto.RemotePeerConfig{peer2, peer3},
expectedSerial: 4,
}
case5 := testCase{
name: "input with one peer to modify",
networkMap: &mgmtProto.NetworkMap{
Serial: 4,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{
modifiedPeer3, peer2,
},
RemotePeersIsEmpty: false,
},
expectedLen: 2,
expectedPeers: []*mgmtProto.RemotePeerConfig{peer2, modifiedPeer3},
expectedSerial: 4,
}
case6 := testCase{
name: "input with all peers to remove",
networkMap: &mgmtProto.NetworkMap{
Serial: 5,
@@ -155,7 +176,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
expectedSerial: 5,
}
for _, c := range []testCase{case1, case2, case3, case4, case5} {
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
t.Run(c.name, func(t *testing.T) {
err = engine.updateNetworkMap(c.networkMap)
if err != nil {
@@ -172,9 +193,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}
for _, p := range c.expectedPeers {
if _, ok := engine.peerConns[p]; !ok {
conn, ok := engine.peerConns[p.GetWgPubKey()]
if !ok {
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
}
expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
if conn.GetConf().ProxyConfig.AllowedIps != expectedAllowedIPs {
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
expectedAllowedIPs, conn.GetConf().ProxyConfig.AllowedIps)
}
}
})
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl"
"net"
"strings"
"sync"
"time"
@@ -66,6 +67,11 @@ type Conn struct {
proxy proxy.Proxy
}
// GetConf returns the connection config
func (conn *Conn) GetConf() ConnConfig {
return conn.config
}
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(config ConnConfig) (*Conn, error) {
@@ -79,27 +85,27 @@ func NewConn(config ConnConfig) (*Conn, error) {
}, nil
}
// interfaceFilter is a function passed to ICE Agent to filter out blacklisted interfaces
// interfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
// to avoid building tunnel over them
func interfaceFilter(blackList []string) func(string) bool {
var blackListMap map[string]struct{}
if blackList != nil {
blackListMap = make(map[string]struct{})
for _, s := range blackList {
blackListMap[s] = struct{}{}
}
}
return func(iFace string) bool {
_, ok := blackListMap[iFace]
if ok {
return false
return func(iFace string) bool {
for _, s := range blackList {
if strings.HasPrefix(iFace, s) {
return false
}
}
// look for unlisted Wireguard interfaces
// look for unlisted WireGuard interfaces
wg, err := wgctrl.New()
if err != nil {
log.Debugf("trying to create a wgctrl client failed with: %v", err)
}
defer wg.Close()
defer func() {
err := wg.Close()
if err != nil {
return
}
}()
_, err = wg.Device(iFace)
return err != nil
@@ -437,7 +443,7 @@ func (conn *Conn) Close() error {
// before conn.Open() another update from management arrives with peers: [1,2,3,4,5]
// engine adds a new Conn for 4 and 5
// therefore peer 4 has 2 Conn objects
log.Warnf("closing not started coonection %s", conn.config.Key)
log.Warnf("connection has been already closed or attempted closing not started coonection %s", conn.config.Key)
return NewConnectionAlreadyClosed(conn.config.Key)
}
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/iface"
"github.com/pion/ice/v2"
"sync"
"testing"
@@ -18,6 +19,18 @@ var connConf = ConnConfig{
ProxyConfig: proxy.Config{},
}
func TestNewConn_interfaceFilter(t *testing.T) {
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"}
filter := interfaceFilter(ignore)
for _, s := range ignore {
assert.Equal(t, filter(s), false)
}
}
func TestConn_GetKey(t *testing.T) {
conn, err := NewConn(connConf)
if err != nil {

View File

@@ -21,7 +21,7 @@ const (
type Config struct {
WgListenAddr string
RemoteKey string
WgInterface iface.WGIface
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}

View File

@@ -30,6 +30,8 @@ func (w *WGIface) configureDevice(config wgtypes.Config) error {
// Configure configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Configure(privateKey string, port int) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("configuring Wireguard interface %s", w.Name)
@@ -76,6 +78,8 @@ func (w *WGIface) GetListenPort() (*int, error) {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s: endpoint %s ", w.Name, peerKey, endpoint)
@@ -110,6 +114,9 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
// RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("Removing peer %s from interface %s ", peerKey, w.Name)
peerKeyParsed, err := wgtypes.ParseKey(peerKey)

View File

@@ -1,10 +1,11 @@
package iface
import (
"golang.zx2c4.com/wireguard/wgctrl"
"fmt"
"net"
"os"
"runtime"
"sync"
)
const (
@@ -19,6 +20,7 @@ type WGIface struct {
MTU int
Address WGAddress
Interface NetInterface
mu sync.Mutex
}
// WGAddress Wireguard parsed address
@@ -27,16 +29,22 @@ type WGAddress struct {
Network *net.IPNet
}
func (addr *WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}
// NetInterface represents a generic network tunnel interface
type NetInterface interface {
Close() error
}
// NewWGIface Creates a new Wireguard interface instance
func NewWGIface(iface string, address string, mtu int) (WGIface, error) {
wgIface := WGIface{
// NewWGIFace Creates a new Wireguard interface instance
func NewWGIFace(iface string, address string, mtu int) (*WGIface, error) {
wgIface := &WGIface{
Name: iface,
MTU: mtu,
mu: sync.Mutex{},
}
wgAddress, err := parseAddress(address)
@@ -49,30 +57,6 @@ func NewWGIface(iface string, address string, mtu int) (WGIface, error) {
return wgIface, nil
}
// Exists checks whether specified Wireguard device exists or not
func Exists(iface string) (*bool, error) {
wg, err := wgctrl.New()
if err != nil {
return nil, err
}
defer wg.Close()
devices, err := wg.Devices()
if err != nil {
return nil, err
}
var exists bool
for _, d := range devices {
if d.Name == iface {
exists = true
return &exists, nil
}
}
exists = false
return &exists, nil
}
// parseAddress parse a string ("1.2.3.4/24") address to WG Address
func parseAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
@@ -85,8 +69,10 @@ func parseAddress(address string) (WGAddress, error) {
}, nil
}
// Closes the tunnel interface
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
err := w.Interface.Close()
if err != nil {

View File

@@ -7,7 +7,10 @@ import (
// Create Creates a new Wireguard interface, sets a given IP and brings it up.
func (w *WGIface) Create() error {
return w.CreateWithUserspace()
w.mu.Lock()
defer w.mu.Unlock()
return w.createWithUserspace()
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided

View File

@@ -2,7 +2,6 @@ package iface
import (
"errors"
"fmt"
"math"
"os"
"syscall"
@@ -33,22 +32,24 @@ func WireguardModExists() bool {
return errors.Is(err, syscall.EINVAL)
}
// Create Creates a new Wireguard interface, sets a given IP and brings it up.
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
if WireguardModExists() {
log.Info("using kernel WireGuard")
return w.CreateWithKernel()
return w.createWithKernel()
} else {
log.Info("using userspace WireGuard")
return w.CreateWithUserspace()
return w.createWithUserspace()
}
}
// CreateWithKernel Creates a new Wireguard interface using kernel Wireguard module.
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module.
// Works for Linux and offers much better network performance
func (w *WGIface) CreateWithKernel() error {
func (w *WGIface) createWithKernel() error {
link := newWGLink(w.Name)
@@ -106,10 +107,6 @@ func (w *WGIface) CreateWithKernel() error {
// assignAddr Adds IP address to the tunnel interface
func (w *WGIface) assignAddr() error {
mask, _ := w.Address.Network.Mask.Size()
address := fmt.Sprintf("%s/%d", w.Address.IP.String(), mask)
link := newWGLink(w.Name)
//delete existing addresses
@@ -126,11 +123,11 @@ func (w *WGIface) assignAddr() error {
}
}
log.Debugf("adding address %s to interface: %s", address, w.Name)
addr, _ := netlink.ParseAddr(address)
log.Debugf("adding address %s to interface: %s", w.Address.String(), w.Name)
addr, _ := netlink.ParseAddr(w.Address.String())
err = netlink.AddrAdd(link, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", w.Name, address)
log.Infof("interface %s already has the address: %s", w.Name, w.Address.String())
} else if err != nil {
return err
}

View File

@@ -3,6 +3,7 @@ package iface
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
@@ -28,11 +29,71 @@ func init() {
peerPubKey = peerPrivateKey.PublicKey().String()
}
func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8"
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU)
if err != nil {
t.Fatal(err)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = iface.Close()
if err != nil {
t.Error(err)
}
}()
port, err := iface.GetListenPort()
if err != nil {
t.Fatal(err)
}
err = iface.Configure(key, *port)
if err != nil {
t.Fatal(err)
}
addrs, err := getIfaceAddrs(ifaceName)
if err != nil {
t.Error(err)
}
assert.Equal(t, addr, addrs[0].String())
//update WireGuard address
addr = "100.64.0.2/8"
err = iface.UpdateAddr(addr)
if err != nil {
t.Fatal(err)
}
addrs, err = getIfaceAddrs(ifaceName)
if err != nil {
t.Error(err)
}
assert.Equal(t, addr, addrs[0].String())
}
func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
ief, err := net.InterfaceByName(ifaceName)
if err != nil {
return nil, err
}
addrs, err := ief.Addrs()
if err != nil {
return nil, err
}
return addrs, nil
}
//
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
iface, err := NewWGIface(ifaceName, wgIP, DefaultMTU)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
if err != nil {
t.Fatal(err)
}
@@ -61,7 +122,7 @@ func Test_CreateInterface(t *testing.T) {
func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
iface, err := NewWGIface(ifaceName, wgIP, DefaultMTU)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
if err != nil {
t.Fatal(err)
}
@@ -89,7 +150,7 @@ func Test_Close(t *testing.T) {
func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
iface, err := NewWGIface(ifaceName, wgIP, DefaultMTU)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
if err != nil {
t.Fatal(err)
}
@@ -136,7 +197,7 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30"
iface, err := NewWGIface(ifaceName, wgIP, DefaultMTU)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
if err != nil {
t.Fatal(err)
}
@@ -195,7 +256,7 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30"
iface, err := NewWGIface(ifaceName, wgIP, DefaultMTU)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU)
if err != nil {
t.Fatal(err)
}
@@ -247,7 +308,7 @@ func Test_ConnectPeers(t *testing.T) {
keepAlive := 1 * time.Second
iface1, err := NewWGIface(peer1ifaceName, peer1wgIP, DefaultMTU)
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU)
if err != nil {
t.Fatal(err)
}
@@ -264,7 +325,7 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
iface2, err := NewWGIface(peer2ifaceName, peer2wgIP, DefaultMTU)
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU)
if err != nil {
t.Fatal(err)
}

View File

@@ -12,8 +12,8 @@ import (
"net"
)
// CreateWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
func (w *WGIface) CreateWithUserspace() error {
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
func (w *WGIface) createWithUserspace() error {
tunIface, err := tun.CreateTUN(w.Name, w.MTU)
if err != nil {
@@ -61,3 +61,17 @@ func getUAPI(iface string) (net.Listener, error) {
}
return ipc.UAPIListen(iface, tunSock)
}
// UpdateAddr updates address of the interface
func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := parseAddress(newAddr)
if err != nil {
return err
}
w.Address = addr
return w.assignAddr()
}

View File

@@ -11,6 +11,8 @@ import (
// Create Creates a new Wireguard interface, sets a given IP and brings it up.
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
WintunStaticRequestedGUID, _ := windows.GenerateGUID()
adapter, err := driver.CreateAdapter(w.Name, "WireGuard", &WintunStaticRequestedGUID)
@@ -40,3 +42,18 @@ func (w *WGIface) assignAddr(luid winipcfg.LUID) error {
return nil
}
// UpdateAddr updates address of the interface
func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
luid := w.Interface.(*driver.Adapter).LUID()
addr, err := parseAddress(newAddr)
if err != nil {
return err
}
w.Address = addr
return w.assignAddr(luid)
}

View File

@@ -328,12 +328,11 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) {
queriedUsers, err = am.idpManager.GetAllUsers(accountID)
queriedUsers, err = am.idpManager.GetBatchedUserData(accountID)
if err != nil {
return nil, err
}
}
// TODO: we need to check whether we need to refresh our cache or not
userInfo := make([]*UserInfo, 0)
@@ -353,6 +352,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
for _, queriedUser := range queriedUsers {
if localUser, contains := account.Users[queriedUser.ID]; contains {
userInfo = append(userInfo, mergeLocalAndQueryUser(*queriedUser, *localUser))
log.Debugf("Merged userinfo to send back; %v", userInfo)
}
}

View File

@@ -29,9 +29,6 @@ type Server struct {
jwtMiddleware *middleware.JWTMiddleware
}
// AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.30.30.1/32)
const AllowedIPsFormat = "%s/32"
// NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager) (*Server, error) {
key, err := wgtypes.GeneratePrivateKey()
@@ -227,7 +224,7 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe
peersToSend = append(peersToSend, p)
}
}
update := toSyncResponse(s.config, peer, peersToSend, nil, networkMap.Network.CurrentSerial())
update := toSyncResponse(s.config, remotePeer, peersToSend, nil, networkMap.Network.CurrentSerial())
err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update})
if err != nil {
// todo rethink if we should keep this return
@@ -368,7 +365,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
func toPeerConfig(peer *Peer) *proto.PeerConfig {
return &proto.PeerConfig{
Address: peer.IP.String() + "/16", // todo make it explicit
Address: fmt.Sprintf("%s/%d", peer.IP.String(), SubnetSize), // take it from the network
}
}
@@ -377,7 +374,7 @@ func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, // todo /32
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
})
}

View File

@@ -1,9 +1,6 @@
package idp
import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
@@ -21,11 +18,10 @@ import (
// Auth0Manager auth0 manager client instance
type Auth0Manager struct {
authIssuer string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
cachedUsersByAccountId map[string][]Auth0Profile
authIssuer string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
}
// Auth0ClientConfig auth0 manager client configurations
@@ -55,38 +51,6 @@ type Auth0Credentials struct {
mux sync.Mutex
}
type Auth0Profile struct {
AccountId string `json:"wt_account_id"`
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
CreatedAt string `json:"created_at"`
LastLogin string `json:"last_login"`
}
type UserExportJobResponse struct {
Type string `json:"type"`
Status string `json:"status"`
ConnectionId string `json:"connection_id"`
Format string `json:"format"`
Limit int `json:"limit"`
Connection string `json:"connection"`
CreatedAt time.Time `json:"created_at"`
Id string `json:"id"`
}
type ExportJobStatusResponse struct {
Type string `json:"type"`
Status string `json:"status"`
ConnectionId string `json:"connection_id"`
Format string `json:"format"`
Limit int `json:"limit"`
Location string `json:"location"`
Connection string `json:"connection"`
CreatedAt time.Time `json:"created_at"`
Id string `json:"id"`
}
// NewAuth0Manager creates a new instance of the Auth0Manager
func NewAuth0Manager(config Auth0ClientConfig) (*Auth0Manager, error) {
@@ -117,13 +81,11 @@ func NewAuth0Manager(config Auth0ClientConfig) (*Auth0Manager, error) {
httpClient: httpClient,
helper: helper,
}
return &Auth0Manager{
authIssuer: config.AuthIssuer,
credentials: credentials,
httpClient: httpClient,
helper: helper,
cachedUsersByAccountId: make(map[string][]Auth0Profile),
authIssuer: config.AuthIssuer,
credentials: credentials,
httpClient: httpClient,
helper: helper,
}, nil
}
@@ -224,198 +186,44 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
return c.jwtToken, nil
}
// Gets all users from cache, if the cache exists
// Otherwise we will initialize the cache with creating the export job on auth0
func (am *Auth0Manager) GetAllUsers(accountId string) ([]*UserData, error) {
if len(am.cachedUsersByAccountId[accountId]) == 0 {
err := am.createExportUsersJob(accountId)
if err != nil {
log.Debugf("Couldn't cache users; %v", err)
return nil, err
}
func batchRequestUsersUrl(authIssuer, accountId string, page int) (string, url.Values, error) {
u, err := url.Parse(authIssuer + "/api/v2/users")
if err != nil {
return "", nil, err
}
q := u.Query()
q.Set("page", strconv.Itoa(page))
q.Set("search_engine", "v3")
q.Set("q", "app_metadata.wt_account_id:"+accountId)
u.RawQuery = q.Encode()
return u.String(), q, nil
}
func requestByUserIdUrl(authIssuer, userId string) string {
return authIssuer + "/api/v2/users/" + userId
}
// GetBatchedUserData requests users in batches from Auth0
func (am *Auth0Manager) GetBatchedUserData(accountId string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
var list []*UserData
cachedUsers := am.cachedUsersByAccountId[accountId]
for _, val := range cachedUsers {
list = append(list, &UserData{
Name: val.Name,
Email: val.Email,
ID: val.UserID,
})
}
return list, nil
}
// This creates an export job on auth0 for all users.
func (am *Auth0Manager) createExportUsersJob(accountId string) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
reqURL := am.authIssuer + "/api/v2/jobs/users-exports"
payloadString := fmt.Sprintf("{\"format\": \"json\"," +
"\"fields\": [{\"name\": \"created_at\"}, {\"name\": \"last_login\"},{\"name\": \"user_id\"}, {\"name\": \"email\"}, {\"name\": \"name\"}, {\"name\": \"app_metadata.wt_account_id\", \"export_as\": \"wt_account_id\"}]}")
payload := strings.NewReader(payloadString)
exportJobReq, err := http.NewRequest("POST", reqURL, payload)
if err != nil {
return err
}
exportJobReq.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
exportJobReq.Header.Add("content-type", "application/json")
jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
return err
}
defer func() {
err = jobResp.Body.Close()
if err != nil {
log.Errorf("error while closing update user app metadata response body: %v", err)
}
}()
if jobResp.StatusCode != 200 {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
}
var exportJobResp UserExportJobResponse
body, err := ioutil.ReadAll(jobResp.Body)
if err != nil {
log.Debugf("Coudln't read export job response; %v", err)
return err
}
err = am.helper.Unmarshal(body, &exportJobResp)
if err != nil {
log.Debugf("Coudln't unmarshal export job response; %v", err)
return err
}
if exportJobResp.Id == "" {
return fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
}
log.Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
ctx, cancel := context.WithTimeout(context.TODO(), 90*time.Second)
defer cancel()
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.Id)
if err != nil {
log.Debugf("Failed at getting status checks from exportJob; %v", err)
return err
}
if done {
err = am.cacheUsers(downloadLink)
if err != nil {
log.Debugf("Failed to cache users via download link; %v", err)
}
}
return nil
}
// Downloads the users from auth0 and caches it in memory
// Users are only cached if they have an wt_account_id stored in auth0
func (am *Auth0Manager) cacheUsers(location string) error {
body, err := doGetReq(am.httpClient, location, "")
if err != nil {
log.Debugf("Can't download cached users; %v", err)
return err
}
bodyReader := bytes.NewReader(body)
gzipReader, err := gzip.NewReader(bodyReader)
if err != nil {
return err
}
decoder := json.NewDecoder(gzipReader)
for decoder.More() {
profile := Auth0Profile{}
err = decoder.Decode(&profile)
if err != nil {
log.Errorf("Couldn't decode profile; %v", err)
return err
}
if profile.AccountId != "" {
am.cachedUsersByAccountId[profile.AccountId] = append(am.cachedUsersByAccountId[profile.AccountId], profile)
}
}
return nil
}
// This checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobId string) (bool, string, error) {
retry := time.NewTicker(time.Second)
for {
select {
case <-ctx.Done():
log.Debugf("Export job status stopped...\n")
return false, "", ctx.Err()
case <-retry.C:
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return false, "", err
}
statusUrl := am.authIssuer + "/api/v2/jobs/" + jobId
body, err := doGetReq(am.httpClient, statusUrl, jwtToken.AccessToken)
if err != nil {
return false, "", err
}
var status ExportJobStatusResponse
err = am.helper.Unmarshal(body, &status)
if err != nil {
return false, "", err
}
log.Debugf("Current export job status is %v", status.Status)
if status.Status != "completed" {
continue
}
return true, status.Location, nil
}
}
}
// Invalidates old cache for Account and re-queries it from auth0
func (am *Auth0Manager) forceUpdateUserCache(accountId string) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
var list []Auth0Profile
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
// auth0 limitation of 1000 users via this endpoint
for page := 0; page < 20; page++ {
reqURL, query, err := batchRequestUsersUrl(am.authIssuer, accountId, page)
if err != nil {
return err
return nil, err
}
req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(query.Encode()))
if err != nil {
return err
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
@@ -423,42 +231,41 @@ func (am *Auth0Manager) forceUpdateUserCache(accountId string) error {
res, err := am.httpClient.Do(req)
if err != nil {
return err
return nil, err
}
body, err := io.ReadAll(res.Body)
if err != nil {
return err
return nil, err
}
var batch []Auth0Profile
var batch []UserData
err = json.Unmarshal(body, &batch)
if err != nil {
return err
return nil, err
}
log.Debugf("requested batch; %v", batch)
err = res.Body.Close()
if err != nil {
return err
return nil, err
}
if res.StatusCode != 200 {
return fmt.Errorf("unable to request UserData from auth0, statusCode %d", res.StatusCode)
return nil, fmt.Errorf("unable to request UserData from auth0, statusCode %d", res.StatusCode)
}
if len(batch) == 0 {
return nil
return list, nil
}
for user := range batch {
list = append(list, batch[user])
list = append(list, &batch[user])
}
}
am.cachedUsersByAccountId[accountId] = list
return nil
return list, nil
}
// GetUserDataByID requests user data from auth0 via ID
@@ -552,54 +359,3 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userId string, appMetadata AppMeta
return nil
}
func batchRequestUsersUrl(authIssuer, accountId string, page int) (string, url.Values, error) {
u, err := url.Parse(authIssuer + "/api/v2/users")
if err != nil {
return "", nil, err
}
q := u.Query()
q.Set("page", strconv.Itoa(page))
q.Set("search_engine", "v3")
q.Set("q", "app_metadata.wt_account_id:"+accountId)
u.RawQuery = q.Encode()
return u.String(), q, nil
}
func requestByUserIdUrl(authIssuer, userId string) string {
return authIssuer + "/api/v2/users/" + userId
}
// Boilerplate implementation for Get Requests.
func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
if accessToken != "" {
req.Header.Add("authorization", "Bearer "+accessToken)
}
res, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() {
err = res.Body.Close()
if err != nil {
log.Errorf("error while closing body for url %s: %v", url, err)
}
}()
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, err
}
return body, nil
}

View File

@@ -11,7 +11,7 @@ import (
type Manager interface {
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error)
GetAllUsers(accountId string) ([]*UserData, error)
GetBatchedUserData(accountId string) ([]*UserData, error)
}
// Config an idp configuration struct to be loaded from management server's config file

View File

@@ -11,6 +11,16 @@ import (
"time"
)
const (
// SubnetSize is a size of the subnet of the global network, e.g. 100.77.0.0/16
SubnetSize = 16
// NetSize is a global network size 100.64.0.0/10
NetSize = 10
// AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.64.30.1/32)
AllowedIPsFormat = "%s/32"
)
type NetworkMap struct {
Peers []*Peer
Network *Network
@@ -31,8 +41,8 @@ type Network struct {
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
func NewNetwork() *Network {
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), 10)
sub, _ := n.Subnet(16)
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
sub, _ := n.Subnet(SubnetSize)
s := rand.NewSource(time.Now().Unix())
r := rand.New(s)