Revert context changes in proxy implementations

This commit is contained in:
Zoltán Papp
2024-05-15 00:27:40 +02:00
parent 650bca7ca8
commit de2e6557ad
8 changed files with 41 additions and 71 deletions

View File

@@ -259,7 +259,7 @@ func (e *Engine) Start() error {
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.wgProxyFactory = wgproxy.NewFactory(e.clientCtx, e.config.WgPort)
e.wgProxyFactory = wgproxy.NewFactory(e.config.WgPort)
wgIface, err := e.newWgIface()
if err != nil {

View File

@@ -423,7 +423,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
var endpoint net.Addr
if isRelayCandidate(pair.Local) {
log.Debugf("setup relay connection")
conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx)
conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
if err != nil {
return nil, err

View File

@@ -1,7 +1,6 @@
package peer
import (
"context"
"sync"
"testing"
"time"
@@ -36,7 +35,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -51,7 +50,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -88,7 +87,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -124,7 +123,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -154,7 +153,7 @@ func TestConn_Status(t *testing.T) {
}
func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()

View File

@@ -1,17 +1,15 @@
package wgproxy
import "context"
type Factory struct {
wgPort int
ebpfProxy Proxy
}
func (w *Factory) GetProxy(ctx context.Context) Proxy {
func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil {
return w.ebpfProxy
}
return NewWGUserSpaceProxy(ctx, w.wgPort)
return NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {

View File

@@ -3,15 +3,13 @@
package wgproxy
import (
"context"
log "github.com/sirupsen/logrus"
)
func NewFactory(ctx context.Context, wgPort int) *Factory {
func NewFactory(wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
ebpfProxy := NewWGEBPFProxy(wgPort)
err := ebpfProxy.listen()
if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)

View File

@@ -2,8 +2,6 @@
package wgproxy
import "context"
func NewFactory(ctx context.Context, wgPort int) *Factory {
func NewFactory(wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}

View File

@@ -3,7 +3,6 @@
package wgproxy
import (
"context"
"fmt"
"io"
"net"
@@ -23,13 +22,9 @@ import (
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
ctx context.Context
cancel context.CancelFunc
lastUsedPort uint16
localWGListenPort int
ebpfManager ebpfMgr.Manager
lastUsedPort uint16
turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex
@@ -39,7 +34,7 @@ type WGEBPFProxy struct {
}
// NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
@@ -47,8 +42,6 @@ func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
return wgProxy
}
@@ -109,7 +102,6 @@ func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
// CloseConn doing nothing because this type of proxy implementation does not store the connection
func (p *WGEBPFProxy) CloseConn() error {
p.cancel()
return nil
}
@@ -139,27 +131,17 @@ func (p *WGEBPFProxy) Free() error {
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
buf := make([]byte, 1500)
var err error
defer func() {
p.removeTurnConn(endpointPort)
}()
for {
select {
case <-p.ctx.Done():
n, err := remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
default:
var n int
n, err = remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
err = p.sendPkg(buf[:n], endpointPort)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
err = p.sendPkg(buf[:n], endpointPort)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
@@ -168,28 +150,23 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
return
default:
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
return
}
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Infof("turn conn not found by port: %d", addr.Port)
continue
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Infof("turn conn not found by port: %d", addr.Port)
continue
}
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
}
}

View File

@@ -21,12 +21,12 @@ type WGUserSpaceProxy struct {
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
p.ctx, p.cancel = context.WithCancel(ctx)
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
@@ -35,7 +35,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
p.remoteConn = turnConn
var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err