[management,proxy] Add per-target options to reverse proxy (#5501)

This commit is contained in:
Viktor Liu
2026-03-05 17:03:26 +08:00
committed by GitHub
parent 8e7b016be2
commit e601278117
16 changed files with 1599 additions and 445 deletions

View File

@@ -73,7 +73,10 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
}
service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
@@ -132,7 +135,10 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
service := new(rpservice.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)

View File

@@ -6,13 +6,16 @@ import (
"fmt"
"math/big"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
@@ -49,17 +52,25 @@ const (
SourceEphemeral = "ephemeral"
)
type TargetOptions struct {
SkipTLSVerify bool `json:"skip_tls_verify"`
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
}
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
Options TargetOptions `gorm:"embedded" json:"options"`
}
type PasswordAuthConfig struct {
@@ -194,7 +205,7 @@ func (s *Service) ToAPIResponse() *api.Service {
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{
st := api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
@@ -202,7 +213,9 @@ func (s *Service) ToAPIResponse() *api.Service {
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
})
}
st.Options = targetOptionsToAPI(target.Options)
apiTargets = append(apiTargets, st)
}
meta := api.ServiceMeta{
@@ -256,10 +269,14 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
if target.Path != nil {
path = *target.Path
}
pathMappings = append(pathMappings, &proto.PathMapping{
pm := &proto.PathMapping{
Path: path,
Target: targetURL.String(),
})
}
pm.Options = targetOptionsToProto(target.Options)
pathMappings = append(pathMappings, pm)
}
auth := &proto.Authentication{
@@ -312,13 +329,87 @@ func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
// PathRewriteMode controls how the request path is rewritten before forwarding.
type PathRewriteMode string
const (
PathRewritePreserve PathRewriteMode = "preserve"
)
func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
switch mode {
case PathRewritePreserve:
return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
default:
return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
}
}
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
return nil
}
apiOpts := &api.ServiceTargetOptions{}
if opts.SkipTLSVerify {
apiOpts.SkipTlsVerify = &opts.SkipTLSVerify
}
if opts.RequestTimeout != 0 {
s := opts.RequestTimeout.String()
apiOpts.RequestTimeout = &s
}
if opts.PathRewrite != "" {
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
apiOpts.PathRewrite = &pr
}
if len(opts.CustomHeaders) > 0 {
apiOpts.CustomHeaders = &opts.CustomHeaders
}
return apiOpts
}
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
return nil
}
popts := &proto.PathTargetOptions{
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
}
if opts.RequestTimeout != 0 {
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
}
return popts
}
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
var opts TargetOptions
if o.SkipTlsVerify != nil {
opts.SkipTLSVerify = *o.SkipTlsVerify
}
if o.RequestTimeout != nil {
d, err := time.ParseDuration(*o.RequestTimeout)
if err != nil {
return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *o.RequestTimeout, err)
}
opts.RequestTimeout = d
}
if o.PathRewrite != nil {
opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
}
if o.CustomHeaders != nil {
opts.CustomHeaders = *o.CustomHeaders
}
return opts, nil
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
for i, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
@@ -331,6 +422,13 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
if apiTarget.Options != nil {
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
if err != nil {
return err
}
target.Options = opts
}
targets = append(targets, target)
}
s.Targets = targets
@@ -368,6 +466,8 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
}
s.Auth.BearerAuth = bearerAuth
}
return nil
}
func (s *Service) Validate() error {
@@ -400,11 +500,113 @@ func (s *Service) Validate() error {
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
if err := validateTargetOptions(i, &target.Options); err != nil {
return err
}
}
return nil
}
const (
maxRequestTimeout = 5 * time.Minute
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
)
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)
// hopByHopHeaders are headers that must not be set as custom headers
// because they are connection-level and stripped by the proxy.
var hopByHopHeaders = map[string]struct{}{
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"Proxy-Connection": {},
"Te": {},
"Trailer": {},
"Transfer-Encoding": {},
"Upgrade": {},
}
// reservedHeaders are set authoritatively by the proxy or control HTTP framing
// and cannot be overridden.
var reservedHeaders = map[string]struct{}{
"Content-Length": {},
"Content-Type": {},
"Cookie": {},
"Forwarded": {},
"X-Forwarded-For": {},
"X-Forwarded-Host": {},
"X-Forwarded-Port": {},
"X-Forwarded-Proto": {},
"X-Real-Ip": {},
}
func validateTargetOptions(idx int, opts *TargetOptions) error {
if opts.PathRewrite != "" && opts.PathRewrite != PathRewritePreserve {
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
}
if opts.RequestTimeout != 0 {
if opts.RequestTimeout <= 0 {
return fmt.Errorf("target %d: request_timeout must be positive", idx)
}
if opts.RequestTimeout > maxRequestTimeout {
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
}
}
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
return err
}
return nil
}
func validateCustomHeaders(idx int, headers map[string]string) error {
if len(headers) > maxCustomHeaders {
return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders)
}
seen := make(map[string]string, len(headers))
for key, value := range headers {
if !httpHeaderNameRe.MatchString(key) {
return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key)
}
if len(key) > maxHeaderKeyLen {
return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen)
}
if len(value) > maxHeaderValueLen {
return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen)
}
if containsCRLF(key) || containsCRLF(value) {
return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key)
}
canonical := http.CanonicalHeaderKey(key)
if prev, ok := seen[canonical]; ok {
return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, prev, key, canonical)
}
seen[canonical] = key
if _, ok := hopByHopHeaders[canonical]; ok {
return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key)
}
if _, ok := reservedHeaders[canonical]; ok {
return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key)
}
if canonical == "Host" {
return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
}
}
return nil
}
func containsCRLF(s string) bool {
return strings.ContainsAny(s, "\r\n")
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
}
@@ -417,6 +619,12 @@ func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
if len(target.Options.CustomHeaders) > 0 {
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
for k, v := range target.Options.CustomHeaders {
targetCopy.Options.CustomHeaders[k] = v
}
}
targets[i] = &targetCopy
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -87,6 +88,188 @@ func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
assert.Contains(t, err.Error(), "empty target_id")
}
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
tests := []struct {
name string
mode PathRewriteMode
wantErr string
}{
{"empty is default", "", ""},
{"preserve is valid", PathRewritePreserve, ""},
{"unknown rejected", "regex", "unknown path_rewrite mode"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.PathRewrite = tt.mode
err := rp.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
wantErr string
}{
{"valid 30s", 30 * time.Second, ""},
{"valid 2m", 2 * time.Minute, ""},
{"zero is fine", 0, ""},
{"negative", -1 * time.Second, "must be positive"},
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.RequestTimeout = tt.timeout
err := rp.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
t.Run("valid headers", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{
"X-Custom": "value",
"X-Trace": "abc123",
}
assert.NoError(t, rp.Validate())
})
t.Run("CRLF in key", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"}
assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name")
})
t.Run("CRLF in value", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"}
assert.ErrorContains(t, rp.Validate(), "invalid characters")
})
t.Run("hop-by-hop header rejected", func(t *testing.T) {
for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h)
}
})
t.Run("reserved header rejected", func(t *testing.T) {
for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h)
}
})
t.Run("Host header rejected", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"}
assert.ErrorContains(t, rp.Validate(), "pass_host_header")
})
t.Run("too many headers", func(t *testing.T) {
rp := validProxy()
headers := make(map[string]string, 17)
for i := range 17 {
headers[fmt.Sprintf("X-H%d", i)] = "v"
}
rp.Targets[0].Options.CustomHeaders = headers
assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16")
})
t.Run("key too long", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"}
assert.ErrorContains(t, rp.Validate(), "key")
assert.ErrorContains(t, rp.Validate(), "exceeds maximum length")
})
t.Run("value too long", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)}
assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length")
})
t.Run("duplicate canonical keys rejected", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{
"x-custom": "a",
"X-Custom": "b",
}
assert.ErrorContains(t, rp.Validate(), "collide")
})
}
func TestToProtoMapping_TargetOptions(t *testing.T) {
rp := &Service{
ID: "svc-1",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
Enabled: true,
Options: TargetOptions{
SkipTLSVerify: true,
RequestTimeout: 30 * time.Second,
PathRewrite: PathRewritePreserve,
CustomHeaders: map[string]string{"X-Custom": "val"},
},
},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
opts := pm.Path[0].Options
require.NotNil(t, opts, "options should be populated")
assert.True(t, opts.SkipTlsVerify)
assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite)
assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders)
require.NotNil(t, opts.RequestTimeout)
assert.Equal(t, int64(30), opts.RequestTimeout.Seconds)
}
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
rp := &Service{
ID: "svc-1",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string