[management,proxy] Add per-target options to reverse proxy (#5501)

This commit is contained in:
Viktor Liu
2026-03-05 17:03:26 +08:00
committed by GitHub
parent 8e7b016be2
commit e601278117
16 changed files with 1599 additions and 445 deletions

View File

@@ -28,10 +28,12 @@ func BenchmarkServeHTTP(b *testing.B) {
ID: rand.Text(),
AccountID: types.AccountID(rand.Text()),
Host: "app.example.com",
Paths: map[string]*url.URL{
Paths: map[string]*proxy.PathTarget{
"/": {
Scheme: "http",
Host: "10.0.0.1:8080",
URL: &url.URL{
Scheme: "http",
Host: "10.0.0.1:8080",
},
},
},
})
@@ -67,10 +69,12 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
ID: id,
AccountID: types.AccountID(rand.Text()),
Host: host,
Paths: map[string]*url.URL{
Paths: map[string]*proxy.PathTarget{
"/": {
Scheme: "http",
Host: "10.0.0.1:8080",
URL: &url.URL{
Scheme: "http",
Host: "10.0.0.1:8080",
},
},
},
})
@@ -100,15 +104,17 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
b.Fatal(err)
}
paths := make(map[string]*url.URL, pathCount)
paths := make(map[string]*proxy.PathTarget, pathCount)
for i := range pathCount {
path := "/" + rand.Text()
if int64(i) == targetIndex.Int64() {
target = path
}
paths[path] = &url.URL{
Scheme: "http",
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
paths[path] = &proxy.PathTarget{
URL: &url.URL{
Scheme: "http",
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
},
}
}
rp.AddMapping(proxy.Mapping{

View File

@@ -80,14 +80,30 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
capturedData.SetAccountId(result.accountID)
}
pt := result.target
if pt.SkipTLSVerify {
ctx = roundtrip.WithSkipTLSVerify(ctx)
}
if pt.RequestTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
defer cancel()
}
rewriteMatchedPath := result.matchedPath
if pt.PathRewrite == PathRewritePreserve {
rewriteMatchedPath = ""
}
rp := &httputil.ReverseProxy{
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders),
Transport: p.transport,
FlushInterval: -1,
ErrorHandler: proxyErrorHandler,
}
if result.rewriteRedirects {
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
}
rp.ServeHTTP(w, r.WithContext(ctx))
}
@@ -97,16 +113,22 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// forwarding headers and stripping proxy authentication credentials.
// When passHostHeader is true, the original client Host header is preserved
// instead of being rewritten to the backend's address.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) {
// The pathRewrite parameter controls how the request path is transformed.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) {
return func(r *httputil.ProxyRequest) {
// Strip the matched path prefix from the incoming request path before
// SetURL joins it with the target's base path, avoiding path duplication.
if matchedPath != "" && matchedPath != "/" {
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
if r.Out.URL.Path == "" {
r.Out.URL.Path = "/"
switch pathRewrite {
case PathRewritePreserve:
// Keep the full original request path as-is.
default:
if matchedPath != "" && matchedPath != "/" {
// Strip the matched path prefix from the incoming request path before
// SetURL joins it with the target's base path, avoiding path duplication.
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
if r.Out.URL.Path == "" {
r.Out.URL.Path = "/"
}
r.Out.URL.RawPath = ""
}
r.Out.URL.RawPath = ""
}
r.SetURL(target)
@@ -116,6 +138,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Host = target.Host
}
for k, v := range customHeaders {
r.Out.Header.Set(k, v)
}
clientIP := extractClientIP(r.In.RemoteAddr)
if IsTrustedProxy(clientIP, p.trustedProxies) {

View File

@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
t.Run("rewrites host to backend by default", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
})
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", true)
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects https from TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects http without TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
// No TLS, but forced to https
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced http proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "http"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
t.Run("strips nb_session cookie", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
t.Run("strips session_token query parameter", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080/app")
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
rewrite(pr)
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false)
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
rewrite(pr)
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false)
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Proto", "https")
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Port", "8443")
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
// Management builds: path="/heise", target="https://heise.de:443/heise"
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("subpath under prefix also preserved", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// What the behavior WOULD be if target URL had no path (true stripping)
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// Root path "/" — no stripping expected
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.com:443/")
rewrite := p.rewriteFunc(target, "/", false)
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -546,6 +546,82 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
})
}
func TestRewriteFunc_PreservePath(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
t.Run("preserve keeps full request path", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil)
pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/api/users/123", pr.Out.URL.Path,
"preserve should keep the full original request path")
})
t.Run("preserve with root matchedPath", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil)
pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/anything", pr.Out.URL.Path)
})
}
func TestRewriteFunc_CustomHeaders(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
t.Run("injects custom headers", func(t *testing.T) {
headers := map[string]string{
"X-Custom-Auth": "token-abc",
"X-Env": "production",
}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "token-abc", pr.Out.Header.Get("X-Custom-Auth"))
assert.Equal(t, "production", pr.Out.Header.Get("X-Env"))
})
t.Run("nil customHeaders is fine", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
})
t.Run("custom headers override existing request headers", func(t *testing.T) {
headers := map[string]string{"X-Override": "new-value"}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("X-Override", "old-value")
rewrite(pr)
assert.Equal(t, "new-value", pr.Out.Header.Get("X-Override"))
})
}
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"})
pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/api/deep/path", pr.Out.URL.Path, "preserve should keep the full original path")
assert.Equal(t, "proxy", pr.Out.Header.Get("X-Via"), "custom header should be set")
}
func TestRewriteLocationFunc(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }

View File

@@ -6,21 +6,41 @@ import (
"net/url"
"sort"
"strings"
"time"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// PathRewriteMode controls how the request path is rewritten before forwarding.
type PathRewriteMode int
const (
// PathRewriteDefault strips the matched prefix and joins with the target path.
PathRewriteDefault PathRewriteMode = iota
// PathRewritePreserve keeps the full original request path as-is.
PathRewritePreserve
)
// PathTarget holds a backend URL and per-target behavioral options.
type PathTarget struct {
URL *url.URL
SkipTLSVerify bool
RequestTimeout time.Duration
PathRewrite PathRewriteMode
CustomHeaders map[string]string
}
type Mapping struct {
ID string
AccountID types.AccountID
Host string
Paths map[string]*url.URL
Paths map[string]*PathTarget
PassHostHeader bool
RewriteRedirects bool
}
type targetResult struct {
url *url.URL
target *PathTarget
matchedPath string
serviceID string
accountID types.AccountID
@@ -55,10 +75,14 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
for _, path := range paths {
if strings.HasPrefix(req.URL.Path, path) {
target := m.Paths[path]
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target)
pt := m.Paths[path]
if pt == nil || pt.URL == nil {
p.logger.Warnf("invalid mapping for host: %s, path: %s (nil target)", host, path)
continue
}
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, pt.URL)
return targetResult{
url: target,
target: pt,
matchedPath: path,
serviceID: m.ID,
accountID: m.AccountID,

View File

@@ -0,0 +1,32 @@
package roundtrip
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/proxy/internal/types"
)
func TestAccountIDContext(t *testing.T) {
t.Run("returns empty when missing", func(t *testing.T) {
assert.Equal(t, types.AccountID(""), AccountIDFromContext(context.Background()))
})
t.Run("round-trips value", func(t *testing.T) {
ctx := WithAccountID(context.Background(), "acc-123")
assert.Equal(t, types.AccountID("acc-123"), AccountIDFromContext(ctx))
})
}
func TestSkipTLSVerifyContext(t *testing.T) {
t.Run("false by default", func(t *testing.T) {
assert.False(t, skipTLSVerifyFromContext(context.Background()))
})
t.Run("true when set", func(t *testing.T) {
ctx := WithSkipTLSVerify(context.Background())
assert.True(t, skipTLSVerifyFromContext(ctx))
})
}

View File

@@ -2,6 +2,7 @@ package roundtrip
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
@@ -52,9 +53,12 @@ type domainNotification struct {
type clientEntry struct {
client *embed.Client
transport *http.Transport
domains map[domain.Domain]domainInfo
createdAt time.Time
started bool
// insecureTransport is a clone of transport with TLS verification disabled,
// used when per-target skip_tls_verify is set.
insecureTransport *http.Transport
domains map[domain.Domain]domainInfo
createdAt time.Time
started bool
// Per-backend in-flight limiting keyed by target host:port.
// TODO: clean up stale entries when backend targets change.
inflightMu sync.Mutex
@@ -130,6 +134,9 @@ type ClientDebugInfo struct {
// accountIDContextKey is the context key for storing the account ID.
type accountIDContextKey struct{}
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
type skipTLSVerifyContextKey struct{}
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token.
// Multiple domains can share the same client.
@@ -249,27 +256,33 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create a transport using the client dialer. We do this instead of using
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
transport := &http.Transport{
DialContext: client.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: n.transportCfg.maxIdleConns,
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
IdleConnTimeout: n.transportCfg.idleConnTimeout,
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
WriteBufferSize: n.transportCfg.writeBufferSize,
ReadBufferSize: n.transportCfg.readBufferSize,
DisableCompression: n.transportCfg.disableCompression,
}
insecureTransport := transport.Clone()
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: &http.Transport{
DialContext: client.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: n.transportCfg.maxIdleConns,
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
IdleConnTimeout: n.transportCfg.idleConnTimeout,
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
WriteBufferSize: n.transportCfg.writeBufferSize,
ReadBufferSize: n.transportCfg.readBufferSize,
DisableCompression: n.transportCfg.disableCompression,
},
createdAt: time.Now(),
started: false,
inflightMap: make(map[backendKey]chan struct{}),
maxInflight: n.transportCfg.maxInflight,
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: transport,
insecureTransport: insecureTransport,
createdAt: time.Now(),
started: false,
inflightMap: make(map[backendKey]chan struct{}),
maxInflight: n.transportCfg.maxInflight,
}, nil
}
@@ -373,6 +386,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
client := entry.client
transport := entry.transport
insecureTransport := entry.insecureTransport
delete(n.clients, accountID)
n.clientsMux.Unlock()
@@ -387,6 +401,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
@@ -415,6 +430,9 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
}
client := entry.client
transport := entry.transport
if skipTLSVerifyFromContext(req.Context()) {
transport = entry.insecureTransport
}
n.clientsMux.RUnlock()
release, ok := entry.acquireInflight(req.URL.Host)
@@ -457,6 +475,7 @@ func (n *NetBird) StopAll(ctx context.Context) error {
var merr *multierror.Error
for accountID, entry := range n.clients {
entry.transport.CloseIdleConnections()
entry.insecureTransport.CloseIdleConnections()
if err := entry.client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
@@ -579,3 +598,14 @@ func AccountIDFromContext(ctx context.Context) types.AccountID {
}
return accountID
}
// WithSkipTLSVerify marks the context to use an insecure transport that skips
// TLS certificate verification for the backend connection.
func WithSkipTLSVerify(ctx context.Context) context.Context {
return context.WithValue(ctx, skipTLSVerifyContextKey{}, true)
}
func skipTLSVerifyFromContext(ctx context.Context) bool {
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
return v
}

View File

@@ -720,7 +720,7 @@ func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping)
}
func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
paths := make(map[string]*url.URL)
paths := make(map[string]*proxy.PathTarget)
for _, pathMapping := range mapping.GetPath() {
targetURL, err := url.Parse(pathMapping.GetTarget())
if err != nil {
@@ -734,7 +734,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
}).WithError(err).Error("failed to parse target URL for path, skipping")
continue
}
paths[pathMapping.GetPath()] = targetURL
pt := &proxy.PathTarget{URL: targetURL}
if opts := pathMapping.GetOptions(); opts != nil {
pt.SkipTLSVerify = opts.GetSkipTlsVerify()
pt.PathRewrite = protoToPathRewrite(opts.GetPathRewrite())
pt.CustomHeaders = opts.GetCustomHeaders()
if d := opts.GetRequestTimeout(); d != nil {
pt.RequestTimeout = d.AsDuration()
}
}
paths[pathMapping.GetPath()] = pt
}
return proxy.Mapping{
ID: mapping.GetId(),
@@ -746,6 +756,15 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
}
}
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {
switch mode {
case proto.PathRewriteMode_PATH_REWRITE_PRESERVE:
return proxy.PathRewritePreserve
default:
return proxy.PathRewriteDefault
}
}
// debugEndpointAddr returns the address for the debug endpoint.
// If addr is empty, it defaults to localhost:8444 for security.
func debugEndpointAddr(addr string) string {