Compare commits

...

4 Commits

Author SHA1 Message Date
Zoltán Papp
d9bcdcf149 Fix lint 2024-05-15 03:26:30 +02:00
Zoltán Papp
d39814f173 Fix lint 2024-05-15 02:26:29 +02:00
Zoltán Papp
4a2429eb1c Fix cleanup and error handling 2024-05-15 01:36:40 +02:00
Zoltán Papp
de2e6557ad Revert context changes in proxy implementations 2024-05-15 00:27:40 +02:00
9 changed files with 54 additions and 76 deletions

View File

@@ -259,7 +259,7 @@ func (e *Engine) Start() error {
} }
e.ctx, e.cancel = context.WithCancel(e.clientCtx) e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.wgProxyFactory = wgproxy.NewFactory(e.clientCtx, e.config.WgPort) e.wgProxyFactory = wgproxy.NewFactory(e.config.WgPort)
wgIface, err := e.newWgIface() wgIface, err := e.newWgIface()
if err != nil { if err != nil {

View File

@@ -423,7 +423,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
var endpoint net.Addr var endpoint net.Addr
if isRelayCandidate(pair.Local) { if isRelayCandidate(pair.Local) {
log.Debugf("setup relay connection") log.Debugf("setup relay connection")
conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx) conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn) endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -1,7 +1,6 @@
package peer package peer
import ( import (
"context"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -36,7 +35,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -51,7 +50,7 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -88,7 +87,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -124,7 +123,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -154,7 +153,7 @@ func TestConn_Status(t *testing.T) {
} }
func TestConn_Close(t *testing.T) { func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()

View File

@@ -1,17 +1,15 @@
package wgproxy package wgproxy
import "context"
type Factory struct { type Factory struct {
wgPort int wgPort int
ebpfProxy Proxy ebpfProxy Proxy
} }
func (w *Factory) GetProxy(ctx context.Context) Proxy { func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil { if w.ebpfProxy != nil {
return w.ebpfProxy return w.ebpfProxy
} }
return NewWGUserSpaceProxy(ctx, w.wgPort) return NewWGUserSpaceProxy(w.wgPort)
} }
func (w *Factory) Free() error { func (w *Factory) Free() error {

View File

@@ -3,15 +3,13 @@
package wgproxy package wgproxy
import ( import (
"context"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func NewFactory(ctx context.Context, wgPort int) *Factory { func NewFactory(wgPort int) *Factory {
f := &Factory{wgPort: wgPort} f := &Factory{wgPort: wgPort}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort) ebpfProxy := NewWGEBPFProxy(wgPort)
err := ebpfProxy.listen() err := ebpfProxy.listen()
if err != nil { if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)

View File

@@ -2,8 +2,6 @@
package wgproxy package wgproxy
import "context" func NewFactory(wgPort int) *Factory {
func NewFactory(ctx context.Context, wgPort int) *Factory {
return &Factory{wgPort: wgPort} return &Factory{wgPort: wgPort}
} }

View File

@@ -3,7 +3,6 @@
package wgproxy package wgproxy
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -23,13 +22,9 @@ import (
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
ctx context.Context
cancel context.CancelFunc
lastUsedPort uint16
localWGListenPort int localWGListenPort int
ebpfManager ebpfMgr.Manager
lastUsedPort uint16
turnConnStore map[uint16]net.Conn turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex turnConnMutex sync.Mutex
@@ -39,7 +34,7 @@ type WGEBPFProxy struct {
} }
// NewWGEBPFProxy create new WGEBPFProxy instance // NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy { func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy") log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{ wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
@@ -47,8 +42,6 @@ func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
lastUsedPort: 0, lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn), turnConnStore: make(map[uint16]net.Conn),
} }
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
return wgProxy return wgProxy
} }
@@ -109,7 +102,6 @@ func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
// CloseConn doing nothing because this type of proxy implementation does not store the connection // CloseConn doing nothing because this type of proxy implementation does not store the connection
func (p *WGEBPFProxy) CloseConn() error { func (p *WGEBPFProxy) CloseConn() error {
p.cancel()
return nil return nil
} }
@@ -138,28 +130,26 @@ func (p *WGEBPFProxy) Free() error {
} }
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) { func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
buf := make([]byte, 1500)
var err error
defer func() { defer func() {
log.Tracef("stop proxying turn traffic to wg: %d", endpointPort)
p.removeTurnConn(endpointPort) p.removeTurnConn(endpointPort)
}() }()
buf := make([]byte, 1500)
for { for {
select { n, err := remoteConn.Read(buf)
case <-p.ctx.Done(): if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return return
default: }
var n int err = p.sendPkg(buf[:n], endpointPort)
n, err = remoteConn.Read(buf) if err != nil {
if err != nil { if err == io.EOF {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return return
} }
err = p.sendPkg(buf[:n], endpointPort) log.Errorf("failed to write out turn pkg to local conn: %v", err)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
} }
} }
} }
@@ -168,28 +158,23 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
select { n, addr, err := p.conn.ReadFromUDP(buf)
case <-p.ctx.Done(): if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
return return
default: }
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
return
}
p.turnConnMutex.Lock() p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)] conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock() p.turnConnMutex.Unlock()
if !ok { if !ok {
log.Infof("turn conn not found by port: %d", addr.Port) log.Infof("turn conn not found by port: %d", addr.Port)
continue continue
} }
_, err = conn.Write(buf[:n]) _, err = conn.Write(buf[:n])
if err != nil { if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err) log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
} }
} }
} }
@@ -207,11 +192,9 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
} }
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) { func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
log.Tracef("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock() p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock() defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID) delete(p.turnConnStore, turnConnID)
} }
func (p *WGEBPFProxy) nextFreePort() (uint16, error) { func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
@@ -287,17 +270,20 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
err := udpH.SetNetworkLayerForChecksum(ipH) err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil { if err != nil {
return fmt.Errorf("set network layer for checksum: %w", err) log.Errorf("set network layer for checksum: %s", err)
return err
} }
layerBuffer := gopacket.NewSerializeBuffer() layerBuffer := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload) err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil { if err != nil {
return fmt.Errorf("serialize layers: %w", err) log.Errorf("serialize layers: %s", err)
return err
} }
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
return fmt.Errorf("write to raw conn: %w", err) log.Errorf("write to raw conn: %s", err)
return err
} }
return nil return nil
} }

View File

@@ -3,12 +3,11 @@
package wgproxy package wgproxy
import ( import (
"context"
"testing" "testing"
) )
func TestWGEBPFProxy_connStore(t *testing.T) { func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil) p, _ := wgProxy.storeTurnConn(nil)
if p != 1 { if p != 1 {
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535 wgProxy.lastUsedPort = 65535
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ { for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)

View File

@@ -21,12 +21,12 @@ type WGUserSpaceProxy struct {
} }
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy // NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy { func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{ p := &WGUserSpaceProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
} }
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(context.Background())
return p return p
} }
@@ -35,7 +35,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
p.remoteConn = turnConn p.remoteConn = turnConn
var err error var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil { if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err) log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err return nil, err