mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
[management,proxy] Add per-target options to reverse proxy (#5501)
This commit is contained in:
@@ -28,10 +28,12 @@ func BenchmarkServeHTTP(b *testing.B) {
|
||||
ID: rand.Text(),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: map[string]*url.URL{
|
||||
Paths: map[string]*proxy.PathTarget{
|
||||
"/": {
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -67,10 +69,12 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
|
||||
ID: id,
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: host,
|
||||
Paths: map[string]*url.URL{
|
||||
Paths: map[string]*proxy.PathTarget{
|
||||
"/": {
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -100,15 +104,17 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
paths := make(map[string]*url.URL, pathCount)
|
||||
paths := make(map[string]*proxy.PathTarget, pathCount)
|
||||
for i := range pathCount {
|
||||
path := "/" + rand.Text()
|
||||
if int64(i) == targetIndex.Int64() {
|
||||
target = path
|
||||
}
|
||||
paths[path] = &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
|
||||
paths[path] = &proxy.PathTarget{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
|
||||
},
|
||||
}
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
|
||||
@@ -80,14 +80,30 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
capturedData.SetAccountId(result.accountID)
|
||||
}
|
||||
|
||||
pt := result.target
|
||||
|
||||
if pt.SkipTLSVerify {
|
||||
ctx = roundtrip.WithSkipTLSVerify(ctx)
|
||||
}
|
||||
if pt.RequestTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
rewriteMatchedPath := result.matchedPath
|
||||
if pt.PathRewrite == PathRewritePreserve {
|
||||
rewriteMatchedPath = ""
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
|
||||
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders),
|
||||
Transport: p.transport,
|
||||
FlushInterval: -1,
|
||||
ErrorHandler: proxyErrorHandler,
|
||||
}
|
||||
if result.rewriteRedirects {
|
||||
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
|
||||
rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
|
||||
}
|
||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
@@ -97,16 +113,22 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// forwarding headers and stripping proxy authentication credentials.
|
||||
// When passHostHeader is true, the original client Host header is preserved
|
||||
// instead of being rewritten to the backend's address.
|
||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) {
|
||||
// The pathRewrite parameter controls how the request path is transformed.
|
||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) {
|
||||
return func(r *httputil.ProxyRequest) {
|
||||
// Strip the matched path prefix from the incoming request path before
|
||||
// SetURL joins it with the target's base path, avoiding path duplication.
|
||||
if matchedPath != "" && matchedPath != "/" {
|
||||
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
|
||||
if r.Out.URL.Path == "" {
|
||||
r.Out.URL.Path = "/"
|
||||
switch pathRewrite {
|
||||
case PathRewritePreserve:
|
||||
// Keep the full original request path as-is.
|
||||
default:
|
||||
if matchedPath != "" && matchedPath != "/" {
|
||||
// Strip the matched path prefix from the incoming request path before
|
||||
// SetURL joins it with the target's base path, avoiding path duplication.
|
||||
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
|
||||
if r.Out.URL.Path == "" {
|
||||
r.Out.URL.Path = "/"
|
||||
}
|
||||
r.Out.URL.RawPath = ""
|
||||
}
|
||||
r.Out.URL.RawPath = ""
|
||||
}
|
||||
|
||||
r.SetURL(target)
|
||||
@@ -116,6 +138,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
|
||||
r.Out.Host = target.Host
|
||||
}
|
||||
|
||||
for k, v := range customHeaders {
|
||||
r.Out.Header.Set(k, v)
|
||||
}
|
||||
|
||||
clientIP := extractClientIP(r.In.RemoteAddr)
|
||||
|
||||
if IsTrustedProxy(clientIP, p.trustedProxies) {
|
||||
|
||||
@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
t.Run("rewrites host to backend by default", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", true)
|
||||
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("auto detects https from TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("auto detects http without TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
// No TLS, but forced to https
|
||||
|
||||
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("forced http proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "http"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
t.Run("strips nb_session cookie", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
t.Run("strips session_token query parameter", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
|
||||
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
|
||||
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080/app")
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
|
||||
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false)
|
||||
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
|
||||
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false)
|
||||
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
|
||||
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
|
||||
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Proto", "https")
|
||||
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Port", "8443")
|
||||
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
||||
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
|
||||
// Management builds: path="/heise", target="https://heise.de:443/heise"
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
|
||||
t.Run("subpath under prefix also preserved", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
// What the behavior WOULD be if target URL had no path (true stripping)
|
||||
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
|
||||
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
// Root path "/" — no stripping expected
|
||||
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.com:443/")
|
||||
rewrite := p.rewriteFunc(target, "/", false)
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -546,6 +546,82 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_PreservePath(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
t.Run("preserve keeps full request path", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/api/users/123", pr.Out.URL.Path,
|
||||
"preserve should keep the full original request path")
|
||||
})
|
||||
|
||||
t.Run("preserve with root matchedPath", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/anything", pr.Out.URL.Path)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_CustomHeaders(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
t.Run("injects custom headers", func(t *testing.T) {
|
||||
headers := map[string]string{
|
||||
"X-Custom-Auth": "token-abc",
|
||||
"X-Env": "production",
|
||||
}
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "token-abc", pr.Out.Header.Get("X-Custom-Auth"))
|
||||
assert.Equal(t, "production", pr.Out.Header.Get("X-Env"))
|
||||
})
|
||||
|
||||
t.Run("nil customHeaders is fine", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
|
||||
})
|
||||
|
||||
t.Run("custom headers override existing request headers", func(t *testing.T) {
|
||||
headers := map[string]string{"X-Override": "new-value"}
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.Header.Set("X-Override", "old-value")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "new-value", pr.Out.Header.Get("X-Override"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"})
|
||||
pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/api/deep/path", pr.Out.URL.Path, "preserve should keep the full original path")
|
||||
assert.Equal(t, "proxy", pr.Out.Header.Get("X-Via"), "custom header should be set")
|
||||
}
|
||||
|
||||
func TestRewriteLocationFunc(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }
|
||||
|
||||
@@ -6,21 +6,41 @@ import (
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// PathRewriteMode controls how the request path is rewritten before forwarding.
|
||||
type PathRewriteMode int
|
||||
|
||||
const (
|
||||
// PathRewriteDefault strips the matched prefix and joins with the target path.
|
||||
PathRewriteDefault PathRewriteMode = iota
|
||||
// PathRewritePreserve keeps the full original request path as-is.
|
||||
PathRewritePreserve
|
||||
)
|
||||
|
||||
// PathTarget holds a backend URL and per-target behavioral options.
|
||||
type PathTarget struct {
|
||||
URL *url.URL
|
||||
SkipTLSVerify bool
|
||||
RequestTimeout time.Duration
|
||||
PathRewrite PathRewriteMode
|
||||
CustomHeaders map[string]string
|
||||
}
|
||||
|
||||
type Mapping struct {
|
||||
ID string
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*url.URL
|
||||
Paths map[string]*PathTarget
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
}
|
||||
|
||||
type targetResult struct {
|
||||
url *url.URL
|
||||
target *PathTarget
|
||||
matchedPath string
|
||||
serviceID string
|
||||
accountID types.AccountID
|
||||
@@ -55,10 +75,14 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
|
||||
|
||||
for _, path := range paths {
|
||||
if strings.HasPrefix(req.URL.Path, path) {
|
||||
target := m.Paths[path]
|
||||
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target)
|
||||
pt := m.Paths[path]
|
||||
if pt == nil || pt.URL == nil {
|
||||
p.logger.Warnf("invalid mapping for host: %s, path: %s (nil target)", host, path)
|
||||
continue
|
||||
}
|
||||
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, pt.URL)
|
||||
return targetResult{
|
||||
url: target,
|
||||
target: pt,
|
||||
matchedPath: path,
|
||||
serviceID: m.ID,
|
||||
accountID: m.AccountID,
|
||||
|
||||
32
proxy/internal/roundtrip/context_test.go
Normal file
32
proxy/internal/roundtrip/context_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
func TestAccountIDContext(t *testing.T) {
|
||||
t.Run("returns empty when missing", func(t *testing.T) {
|
||||
assert.Equal(t, types.AccountID(""), AccountIDFromContext(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("round-trips value", func(t *testing.T) {
|
||||
ctx := WithAccountID(context.Background(), "acc-123")
|
||||
assert.Equal(t, types.AccountID("acc-123"), AccountIDFromContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSkipTLSVerifyContext(t *testing.T) {
|
||||
t.Run("false by default", func(t *testing.T) {
|
||||
assert.False(t, skipTLSVerifyFromContext(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("true when set", func(t *testing.T) {
|
||||
ctx := WithSkipTLSVerify(context.Background())
|
||||
assert.True(t, skipTLSVerifyFromContext(ctx))
|
||||
})
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -52,9 +53,12 @@ type domainNotification struct {
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
transport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// insecureTransport is a clone of transport with TLS verification disabled,
|
||||
// used when per-target skip_tls_verify is set.
|
||||
insecureTransport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
// TODO: clean up stale entries when backend targets change.
|
||||
inflightMu sync.Mutex
|
||||
@@ -130,6 +134,9 @@ type ClientDebugInfo struct {
|
||||
// accountIDContextKey is the context key for storing the account ID.
|
||||
type accountIDContextKey struct{}
|
||||
|
||||
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
|
||||
type skipTLSVerifyContextKey struct{}
|
||||
|
||||
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||
// one is created by authenticating with the management server using the provided token.
|
||||
// Multiple domains can share the same client.
|
||||
@@ -249,27 +256,33 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
// Create a transport using the client dialer. We do this instead of using
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
transport := &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
|
||||
IdleConnTimeout: n.transportCfg.idleConnTimeout,
|
||||
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
|
||||
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
|
||||
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
|
||||
WriteBufferSize: n.transportCfg.writeBufferSize,
|
||||
ReadBufferSize: n.transportCfg.readBufferSize,
|
||||
DisableCompression: n.transportCfg.disableCompression,
|
||||
}
|
||||
|
||||
insecureTransport := transport.Clone()
|
||||
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec
|
||||
|
||||
return &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
transport: &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
|
||||
IdleConnTimeout: n.transportCfg.idleConnTimeout,
|
||||
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
|
||||
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
|
||||
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
|
||||
WriteBufferSize: n.transportCfg.writeBufferSize,
|
||||
ReadBufferSize: n.transportCfg.readBufferSize,
|
||||
DisableCompression: n.transportCfg.disableCompression,
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
inflightMap: make(map[backendKey]chan struct{}),
|
||||
maxInflight: n.transportCfg.maxInflight,
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
transport: transport,
|
||||
insecureTransport: insecureTransport,
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
inflightMap: make(map[backendKey]chan struct{}),
|
||||
maxInflight: n.transportCfg.maxInflight,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -373,6 +386,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
|
||||
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
insecureTransport := entry.insecureTransport
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
@@ -387,6 +401,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
|
||||
}
|
||||
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -415,6 +430,9 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
if skipTLSVerifyFromContext(req.Context()) {
|
||||
transport = entry.insecureTransport
|
||||
}
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
release, ok := entry.acquireInflight(req.URL.Host)
|
||||
@@ -457,6 +475,7 @@ func (n *NetBird) StopAll(ctx context.Context) error {
|
||||
var merr *multierror.Error
|
||||
for accountID, entry := range n.clients {
|
||||
entry.transport.CloseIdleConnections()
|
||||
entry.insecureTransport.CloseIdleConnections()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
@@ -579,3 +598,14 @@ func AccountIDFromContext(ctx context.Context) types.AccountID {
|
||||
}
|
||||
return accountID
|
||||
}
|
||||
|
||||
// WithSkipTLSVerify marks the context to use an insecure transport that skips
|
||||
// TLS certificate verification for the backend connection.
|
||||
func WithSkipTLSVerify(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, skipTLSVerifyContextKey{}, true)
|
||||
}
|
||||
|
||||
func skipTLSVerifyFromContext(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -720,7 +720,7 @@ func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
}
|
||||
|
||||
func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
paths := make(map[string]*url.URL)
|
||||
paths := make(map[string]*proxy.PathTarget)
|
||||
for _, pathMapping := range mapping.GetPath() {
|
||||
targetURL, err := url.Parse(pathMapping.GetTarget())
|
||||
if err != nil {
|
||||
@@ -734,7 +734,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
}).WithError(err).Error("failed to parse target URL for path, skipping")
|
||||
continue
|
||||
}
|
||||
paths[pathMapping.GetPath()] = targetURL
|
||||
|
||||
pt := &proxy.PathTarget{URL: targetURL}
|
||||
if opts := pathMapping.GetOptions(); opts != nil {
|
||||
pt.SkipTLSVerify = opts.GetSkipTlsVerify()
|
||||
pt.PathRewrite = protoToPathRewrite(opts.GetPathRewrite())
|
||||
pt.CustomHeaders = opts.GetCustomHeaders()
|
||||
if d := opts.GetRequestTimeout(); d != nil {
|
||||
pt.RequestTimeout = d.AsDuration()
|
||||
}
|
||||
}
|
||||
paths[pathMapping.GetPath()] = pt
|
||||
}
|
||||
return proxy.Mapping{
|
||||
ID: mapping.GetId(),
|
||||
@@ -746,6 +756,15 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
}
|
||||
}
|
||||
|
||||
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {
|
||||
switch mode {
|
||||
case proto.PathRewriteMode_PATH_REWRITE_PRESERVE:
|
||||
return proxy.PathRewritePreserve
|
||||
default:
|
||||
return proxy.PathRewriteDefault
|
||||
}
|
||||
}
|
||||
|
||||
// debugEndpointAddr returns the address for the debug endpoint.
|
||||
// If addr is empty, it defaults to localhost:8444 for security.
|
||||
func debugEndpointAddr(addr string) string {
|
||||
|
||||
Reference in New Issue
Block a user