mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 14:06:20 -04:00
Compare commits
4 Commits
test/proxy
...
test/conne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ea7fb7b21 | ||
|
|
93f530637d | ||
|
|
3a84475d14 | ||
|
|
3d7368e51f |
37
.github/workflows/golang-test-linux.yml
vendored
37
.github/workflows/golang-test-linux.yml
vendored
@@ -409,19 +409,12 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
@@ -504,18 +497,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
@@ -596,18 +586,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.23.3
|
||||
FROM alpine:3.23.2
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||
|
||||
var (
|
||||
exposePin string
|
||||
exposePassword string
|
||||
exposeUserGroups []string
|
||||
exposeDomain string
|
||||
exposeNamePrefix string
|
||||
exposeProtocol string
|
||||
)
|
||||
|
||||
var exposeCmd = &cobra.Command{
|
||||
Use: "expose <port>",
|
||||
Short: "Expose a local port via the NetBird reverse proxy",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Example: "netbird expose --with-password safe-pass 8080",
|
||||
RunE: exposeFn,
|
||||
}
|
||||
|
||||
func init() {
|
||||
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
|
||||
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
|
||||
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
|
||||
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
|
||||
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
|
||||
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)")
|
||||
}
|
||||
|
||||
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
|
||||
port, err := strconv.ParseUint(portStr, 10, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid port number: %s", portStr)
|
||||
}
|
||||
if port == 0 || port > 65535 {
|
||||
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
|
||||
}
|
||||
|
||||
if !isProtocolValid(exposeProtocol) {
|
||||
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
|
||||
}
|
||||
|
||||
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
|
||||
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-password") && exposePassword == "" {
|
||||
return 0, fmt.Errorf("password cannot be empty")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
|
||||
return 0, fmt.Errorf("user groups cannot be empty")
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func isProtocolValid(exposeProtocol string) bool {
|
||||
return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https"
|
||||
}
|
||||
|
||||
func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||
log.Errorf("failed initializing log %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = false
|
||||
|
||||
port, err := validateExposeFlags(cmd, args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = true
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigCh
|
||||
cancel()
|
||||
}()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close daemon connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
protocol, err := toExposeProtocol(exposeProtocol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{
|
||||
Port: uint32(port),
|
||||
Protocol: protocol,
|
||||
Pin: exposePin,
|
||||
Password: exposePassword,
|
||||
UserGroups: exposeUserGroups,
|
||||
Domain: exposeDomain,
|
||||
NamePrefix: exposeNamePrefix,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %w", err)
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return waitForExposeEvents(cmd, ctx, stream)
|
||||
}
|
||||
|
||||
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
switch strings.ToLower(exposeProtocol) {
|
||||
case "http":
|
||||
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||
case "https":
|
||||
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %w", err)
|
||||
}
|
||||
|
||||
switch e := event.Event.(type) {
|
||||
case *proto.ExposeServiceEvent_Ready:
|
||||
cmd.Println("Service exposed successfully!")
|
||||
cmd.Printf(" Name: %s\n", e.Ready.ServiceName)
|
||||
cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl)
|
||||
cmd.Printf(" Domain: %s\n", e.Ready.Domain)
|
||||
cmd.Printf(" Protocol: %s\n", exposeProtocol)
|
||||
cmd.Printf(" Port: %d\n", port)
|
||||
cmd.Println()
|
||||
cmd.Println("Press Ctrl+C to stop exposing.")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unexpected expose event: %T", event.Event)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
|
||||
for {
|
||||
_, err := stream.Recv()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
cmd.Println("\nService stopped.")
|
||||
//nolint:nilerr
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("connection to daemon closed unexpectedly")
|
||||
}
|
||||
return fmt.Errorf("stream error: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -81,15 +80,6 @@ var (
|
||||
Short: "",
|
||||
Long: "",
|
||||
SilenceUsage: true,
|
||||
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(cmd.Root())
|
||||
|
||||
// Don't resolve for service commands — they create the socket, not connect to it.
|
||||
if !isServiceCmd(cmd) {
|
||||
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -154,7 +144,6 @@ func init() {
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
rootCmd.AddCommand(profileCmd)
|
||||
rootCmd.AddCommand(exposeCmd)
|
||||
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
@@ -396,6 +385,7 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
||||
}
|
||||
|
||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
@@ -408,13 +398,3 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// isServiceCmd returns true if cmd is the "service" command or a child of it.
|
||||
func isServiceCmd(cmd *cobra.Command) bool {
|
||||
for c := cmd; c != nil; c = c.Parent() {
|
||||
if c.Name() == "service" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -5,18 +5,20 @@ package configurer
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
func openUAPI(deviceName string) (net.Listener, error) {
|
||||
uapiSock, err := ipc.UAPIOpen(deviceName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to open uapi socket: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
||||
if err != nil {
|
||||
_ = uapiSock.Close()
|
||||
log.Errorf("failed to listen on uapi socket: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -54,14 +54,6 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
|
||||
return wgCfg
|
||||
}
|
||||
|
||||
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||
return &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
log.Debugf("adding Wireguard private key")
|
||||
key, err := wgtypes.ParseKey(privateKey)
|
||||
|
||||
@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
if cErr := tunIface.Close(); cErr != nil {
|
||||
|
||||
@@ -589,6 +589,101 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func Test_UserSpaceAddAllowedIPs(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+5)
|
||||
wgIP := "10.99.99.21/30"
|
||||
wgPort := 33105
|
||||
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = iface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := iface.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = iface.Up()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
keepAlive := 15 * time.Second
|
||||
initialAllowedIP := netip.MustParsePrefix("10.99.99.22/32")
|
||||
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9905")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add peer with initial endpoint and first allowed IP
|
||||
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{initialAllowedIP}, keepAlive, endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Phase 1: generate 500 allowed IPs into a list
|
||||
const extraIPs = 500
|
||||
addedPrefixes := make([]netip.Prefix, 0, extraIPs)
|
||||
for i := 0; i < extraIPs; i++ {
|
||||
// Use 172.16.x.y/32 range: i encoded as two octets
|
||||
prefix := netip.MustParsePrefix(fmt.Sprintf("172.16.%d.%d/32", i/256, i%256))
|
||||
addedPrefixes = append(addedPrefixes, prefix)
|
||||
}
|
||||
|
||||
// Phase 2: iterate over the list and add each allowed IP to the peer
|
||||
phase2Start := time.Now()
|
||||
for _, prefix := range addedPrefixes {
|
||||
if addErr := iface.AddAllowedIP(peerPubKey, prefix); addErr != nil {
|
||||
t.Fatalf("failed to add allowed IP %s: %v", prefix, addErr)
|
||||
}
|
||||
}
|
||||
t.Logf("Phase 2 (add %d IPs to peer): %s", extraIPs, time.Since(phase2Start))
|
||||
|
||||
// Verify the peer has all 101 allowed IPs (1 initial + 100 added)
|
||||
peer, err := getPeer(ifaceName, peerPubKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if peer.Endpoint.String() != endpoint.String() {
|
||||
t.Fatalf("expected endpoint %s, got %s", endpoint, peer.Endpoint)
|
||||
}
|
||||
|
||||
allExpected := append([]netip.Prefix{initialAllowedIP}, addedPrefixes...)
|
||||
if len(peer.AllowedIPs) != len(allExpected) {
|
||||
t.Fatalf("expected %d allowed IPs, got %d", len(allExpected), len(peer.AllowedIPs))
|
||||
}
|
||||
|
||||
allowedIPSet := make(map[string]struct{}, len(peer.AllowedIPs))
|
||||
for _, aip := range peer.AllowedIPs {
|
||||
allowedIPSet[aip.String()] = struct{}{}
|
||||
}
|
||||
for _, expected := range allExpected {
|
||||
if _, found := allowedIPSet[expected.String()]; !found {
|
||||
t.Errorf("expected allowed IP %s not found in peer config", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
|
||||
@@ -290,6 +290,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if relayClient.IsDisableRelay() {
|
||||
relayURLs = []string{}
|
||||
}
|
||||
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
if len(relayURLs) > 0 {
|
||||
@@ -310,6 +314,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.engineMutex.Lock()
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
engine.SetReadyChan(runningChan)
|
||||
runningChan = nil
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
@@ -330,14 +336,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
select {
|
||||
case <-runningChan:
|
||||
default:
|
||||
close(runningChan)
|
||||
}
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
//go:build !windows && !ios && !android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var scanDir = "/var/run/netbird"
|
||||
|
||||
// setScanDir overrides the scan directory (used by tests).
|
||||
func setScanDir(dir string) {
|
||||
scanDir = dir
|
||||
}
|
||||
|
||||
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
|
||||
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
|
||||
// mismatch between the netbird@.service template (which places the socket under
|
||||
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
|
||||
func ResolveUnixDaemonAddr(addr string) string {
|
||||
if !strings.HasPrefix(addr, "unix://") {
|
||||
return addr
|
||||
}
|
||||
|
||||
sockPath := strings.TrimPrefix(addr, "unix://")
|
||||
if _, err := os.Stat(sockPath); err == nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(scanDir)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
var found []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(e.Name(), ".sock") {
|
||||
found = append(found, filepath.Join(scanDir, e.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
switch len(found) {
|
||||
case 1:
|
||||
resolved := "unix://" + found[0]
|
||||
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||
return resolved
|
||||
case 0:
|
||||
return addr
|
||||
default:
|
||||
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
|
||||
return addr
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
//go:build windows || ios || android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
|
||||
func ResolveUnixDaemonAddr(addr string) string {
|
||||
return addr
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
//go:build !windows && !ios && !android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// createSockFile creates a regular file with a .sock extension.
|
||||
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
|
||||
// sufficient and avoids Unix socket path-length limits on macOS.
|
||||
func createSockFile(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
if err := os.WriteFile(path, nil, 0o600); err != nil {
|
||||
t.Fatalf("failed to create test sock file at %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
sock := filepath.Join(tmp, "netbird.sock")
|
||||
createSockFile(t, sock)
|
||||
|
||||
addr := "unix://" + sock
|
||||
got := ResolveUnixDaemonAddr(addr)
|
||||
if got != addr {
|
||||
t.Errorf("expected %s, got %s", addr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
// Default socket does not exist
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
// Create a scan dir with one socket
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
instanceSock := filepath.Join(sd, "main.sock")
|
||||
createSockFile(t, instanceSock)
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
expected := "unix://" + instanceSock
|
||||
if got != expected {
|
||||
t.Errorf("expected %s, got %s", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
createSockFile(t, filepath.Join(sd, "main.sock"))
|
||||
createSockFile(t, filepath.Join(sd, "other.sock"))
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
|
||||
addr := "tcp://127.0.0.1:41731"
|
||||
got := ResolveUnixDaemonAddr(addr)
|
||||
if got != addr {
|
||||
t.Errorf("expected %s, got %s", addr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(filepath.Join(tmp, "nonexistent"))
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
@@ -24,7 +22,6 @@ import (
|
||||
|
||||
const (
|
||||
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
||||
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
|
||||
globalIPv4State = "State:/Network/Global/IPv4"
|
||||
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||
@@ -38,14 +35,6 @@ const (
|
||||
searchSuffix = "Search"
|
||||
matchSuffix = "Match"
|
||||
localSuffix = "Local"
|
||||
|
||||
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
|
||||
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
|
||||
maxDomainsPerResolverEntry = 50
|
||||
|
||||
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
|
||||
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
|
||||
maxDomainBytesPerResolverEntry = 1500
|
||||
)
|
||||
|
||||
type systemConfigurator struct {
|
||||
@@ -95,23 +84,28 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||
}
|
||||
|
||||
if err := s.removeKeysContaining(matchSuffix); err != nil {
|
||||
log.Warnf("failed to remove old match keys: %v", err)
|
||||
}
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
var err error
|
||||
if len(matchDomains) != 0 {
|
||||
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing match domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(matchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
if err := s.removeKeysContaining(searchSuffix); err != nil {
|
||||
log.Warnf("failed to remove old search keys: %v", err)
|
||||
}
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
if len(searchDomains) != 0 {
|
||||
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing search domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(searchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
@@ -155,7 +149,8 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
|
||||
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
if len(s.createdKeys) == 0 {
|
||||
return s.discoverExistingKeys()
|
||||
// return defaults for startup calls
|
||||
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(s.createdKeys))
|
||||
@@ -165,47 +160,6 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
return keys
|
||||
}
|
||||
|
||||
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
|
||||
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
|
||||
func (s *systemConfigurator) discoverExistingKeys() []string {
|
||||
dnsKeys, err := getSystemDNSKeys()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get system DNS keys: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var keys []string
|
||||
|
||||
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
|
||||
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
|
||||
if strings.Contains(dnsKeys, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, suffix := range []string{searchSuffix, matchSuffix} {
|
||||
for i := 0; ; i++ {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||
if !strings.Contains(dnsKeys, key) {
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// getSystemDNSKeys gets all DNS keys
|
||||
func getSystemDNSKeys() (string, error) {
|
||||
command := "list .*DNS\nquit\n"
|
||||
out, err := runSystemConfigCommand(command)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
line := buildRemoveKeyOperation(key)
|
||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||
@@ -230,11 +184,12 @@ func (s *systemConfigurator) addLocalDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
|
||||
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
|
||||
return fmt.Errorf("add local dns state: %w", err)
|
||||
if err := s.addSearchDomains(
|
||||
localKey,
|
||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||
); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
s.createdKeys[localKey] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -325,77 +280,28 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(s.origNameservers)
|
||||
}
|
||||
|
||||
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
|
||||
func splitDomainsIntoBatches(domains []string) [][]string {
|
||||
if len(domains) == 0 {
|
||||
return nil
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
var batches [][]string
|
||||
var current []string
|
||||
currentBytes := 0
|
||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
|
||||
for _, d := range domains {
|
||||
domainLen := len(d)
|
||||
newBytes := currentBytes + domainLen
|
||||
if currentBytes > 0 {
|
||||
newBytes++ // space separator
|
||||
}
|
||||
s.createdKeys[key] = struct{}{}
|
||||
|
||||
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
|
||||
batches = append(batches, current)
|
||||
current = nil
|
||||
currentBytes = 0
|
||||
}
|
||||
|
||||
current = append(current, d)
|
||||
if currentBytes > 0 {
|
||||
currentBytes += 1 + domainLen
|
||||
} else {
|
||||
currentBytes = domainLen
|
||||
}
|
||||
}
|
||||
|
||||
if len(current) > 0 {
|
||||
batches = append(batches, current)
|
||||
}
|
||||
|
||||
return batches
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeKeysContaining removes all created keys that contain the given substring.
|
||||
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
|
||||
var toRemove []string
|
||||
for key := range s.createdKeys {
|
||||
if strings.Contains(key, suffix) {
|
||||
toRemove = append(toRemove, key)
|
||||
}
|
||||
}
|
||||
var multiErr *multierror.Error
|
||||
for _, key := range toRemove {
|
||||
if err := s.removeKeyFromSystemConfig(key); err != nil {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(multiErr)
|
||||
}
|
||||
|
||||
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
|
||||
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
|
||||
batches := splitDomainsIntoBatches(domains)
|
||||
|
||||
for i, batch := range batches {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||
domainsStr := strings.Join(batch, " ")
|
||||
|
||||
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
|
||||
return fmt.Errorf("add dns state for batch %d: %w", i, err)
|
||||
}
|
||||
|
||||
s.createdKeys[key] = struct{}{}
|
||||
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
|
||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
|
||||
s.createdKeys[key] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -458,6 +364,7 @@ func (s *systemConfigurator) flushDNSCache() error {
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
log.Info("flushed DNS cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,10 +3,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -52,22 +49,17 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
|
||||
require.NoError(t, sm.PersistState(context.Background()))
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
// Collect all created keys for cleanup verification
|
||||
createdKeys := make([]string, 0, len(configurator.createdKeys))
|
||||
for key := range configurator.createdKeys {
|
||||
createdKeys = append(createdKeys, key)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, key := range createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
_ = removeTestDNSKey(localKey)
|
||||
}()
|
||||
|
||||
for _, key := range createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
if exists {
|
||||
@@ -91,223 +83,13 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
err = shutdownState.Cleanup()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, key := range createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
|
||||
func generateShortDomains(count int) []string {
|
||||
domains := make([]string, 0, count)
|
||||
for i := range count {
|
||||
label := ""
|
||||
n := i
|
||||
for {
|
||||
label = string(rune('a'+n%26)) + label
|
||||
n = n/26 - 1
|
||||
if n < 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
domains = append(domains, label+".com")
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
|
||||
func generateLongDomains(count int) []string {
|
||||
domains := make([]string, 0, count)
|
||||
for i := range count {
|
||||
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
|
||||
func readDomainsFromKey(t *testing.T, key string) []string {
|
||||
t.Helper()
|
||||
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
|
||||
out, err := cmd.Output()
|
||||
require.NoError(t, err, "scutil show should succeed")
|
||||
|
||||
var domains []string
|
||||
inArray := false
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
|
||||
inArray = true
|
||||
continue
|
||||
}
|
||||
if inArray {
|
||||
if line == "}" {
|
||||
break
|
||||
}
|
||||
// lines look like: "0 : a.com"
|
||||
parts := strings.SplitN(line, " : ", 2)
|
||||
if len(parts) == 2 {
|
||||
domains = append(domains, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NoError(t, scanner.Err())
|
||||
return domains
|
||||
}
|
||||
|
||||
func TestSplitDomainsIntoBatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
expectedCount int
|
||||
checkAllPresent bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
domains: nil,
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "under_limit",
|
||||
domains: generateShortDomains(10),
|
||||
expectedCount: 1,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "at_element_limit",
|
||||
domains: generateShortDomains(50),
|
||||
expectedCount: 1,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "over_element_limit",
|
||||
domains: generateShortDomains(51),
|
||||
expectedCount: 2,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "triple_element_limit",
|
||||
domains: generateShortDomains(150),
|
||||
expectedCount: 3,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "long_domains_hit_byte_limit",
|
||||
domains: generateLongDomains(50),
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "500_short_domains",
|
||||
domains: generateShortDomains(500),
|
||||
expectedCount: 10,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "500_long_domains",
|
||||
domains: generateLongDomains(500),
|
||||
checkAllPresent: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
batches := splitDomainsIntoBatches(tc.domains)
|
||||
|
||||
if tc.expectedCount > 0 {
|
||||
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
|
||||
}
|
||||
|
||||
// Verify each batch respects limits
|
||||
for i, batch := range batches {
|
||||
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
|
||||
"batch %d exceeds element limit", i)
|
||||
|
||||
totalBytes := 0
|
||||
for j, d := range batch {
|
||||
if j > 0 {
|
||||
totalBytes++
|
||||
}
|
||||
totalBytes += len(d)
|
||||
}
|
||||
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
|
||||
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
|
||||
}
|
||||
|
||||
if tc.checkAllPresent {
|
||||
var all []string
|
||||
for _, batch := range batches {
|
||||
all = append(all, batch...)
|
||||
}
|
||||
assert.Equal(t, tc.domains, all, "all domains should be present in order")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
|
||||
// and verifies all domains are readable across multiple scutil keys.
|
||||
func TestMatchDomainBatching(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping scutil integration test in short mode")
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
count int
|
||||
generator func(int) []string
|
||||
}{
|
||||
{"short_10", 10, generateShortDomains},
|
||||
{"short_50", 50, generateShortDomains},
|
||||
{"short_100", 100, generateShortDomains},
|
||||
{"short_200", 200, generateShortDomains},
|
||||
{"short_500", 500, generateShortDomains},
|
||||
{"long_10", 10, generateLongDomains},
|
||||
{"long_50", 50, generateLongDomains},
|
||||
{"long_100", 100, generateLongDomains},
|
||||
{"long_200", 200, generateLongDomains},
|
||||
{"long_500", 500, generateLongDomains},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for key := range configurator.createdKeys {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}()
|
||||
|
||||
domains := tc.generator(tc.count)
|
||||
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
batches := splitDomainsIntoBatches(domains)
|
||||
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
|
||||
|
||||
// Read back all domains from all batched keys
|
||||
var got []string
|
||||
for i := range batches {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists, "key %s should exist", key)
|
||||
|
||||
got = append(got, readDomainsFromKey(t, key)...)
|
||||
}
|
||||
|
||||
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
|
||||
assert.Equal(t, tc.count, len(got), "all domains should be readable")
|
||||
assert.Equal(t, domains, got, "domains should match in order")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNSKeyExists(key string) (bool, error) {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||
@@ -376,15 +158,15 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
cleanup := func() {
|
||||
_ = sm.Stop(context.Background())
|
||||
for key := range configurator.createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
// Also clean up old-format keys and local key in case they exist
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
|
||||
}
|
||||
|
||||
return configurator, sm, cleanup
|
||||
|
||||
@@ -277,7 +277,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
|
||||
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
|
||||
return ruleIndex, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
}
|
||||
}
|
||||
|
||||
// Flow receiver domain is intentionally excluded from caching.
|
||||
// Cloud providers may rotate the IP behind this domain; a stale cached record
|
||||
// causes TLS certificate verification failures on reconnect.
|
||||
if serverDomains.Flow != "" {
|
||||
domains = append(domains, serverDomains.Flow)
|
||||
}
|
||||
|
||||
for _, stun := range serverDomains.Stuns {
|
||||
if stun != "" {
|
||||
|
||||
@@ -391,8 +391,7 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
}
|
||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||
|
||||
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
|
||||
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
|
||||
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
|
||||
partialDomains := dnsconfig.ServerDomains{
|
||||
Flow: "github.com",
|
||||
}
|
||||
@@ -401,10 +400,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
|
||||
|
||||
finalDomains := resolver.GetCachedDomains()
|
||||
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
|
||||
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
|
||||
|
||||
domainStrings := make([]string, len(finalDomains))
|
||||
for i, d := range finalDomains {
|
||||
@@ -413,5 +412,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
assert.Contains(t, domainStrings, "example.org")
|
||||
assert.Contains(t, domainStrings, "google.com")
|
||||
assert.Contains(t, domainStrings, "cloudflare.com")
|
||||
assert.NotContains(t, domainStrings, "github.com")
|
||||
assert.Contains(t, domainStrings, "github.com")
|
||||
}
|
||||
|
||||
@@ -351,13 +351,9 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
return fmt.Errorf("upstream check call error")
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||
err := backoff.Retry(operation, exponentialBackOff)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||
} else {
|
||||
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||
}
|
||||
log.Warn(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
@@ -54,11 +53,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/jobexec"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -74,6 +75,7 @@ import (
|
||||
const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
connInitLimit = 200
|
||||
disableAutoUpdate = "disabled"
|
||||
)
|
||||
|
||||
@@ -206,6 +208,7 @@ type Engine struct {
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
@@ -214,6 +217,10 @@ type Engine struct {
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
|
||||
// readyChan is closed when the first sync message is received from management
|
||||
readyChan chan struct{}
|
||||
readyChanOnce sync.Once
|
||||
|
||||
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
|
||||
shutdownWg sync.WaitGroup
|
||||
|
||||
@@ -221,8 +228,6 @@ type Engine struct {
|
||||
|
||||
jobExecutor *jobexec.Executor
|
||||
jobExecutorWG sync.WaitGroup
|
||||
|
||||
exposeManager *expose.Manager
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -265,6 +270,7 @@ func NewEngine(
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
}
|
||||
@@ -273,6 +279,10 @@ func NewEngine(
|
||||
return engine
|
||||
}
|
||||
|
||||
func (e *Engine) SetReadyChan(ch chan struct{}) {
|
||||
e.readyChan = ch
|
||||
}
|
||||
|
||||
func (e *Engine) Stop() error {
|
||||
if e == nil {
|
||||
// this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started
|
||||
@@ -417,7 +427,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
@@ -800,7 +809,7 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
||||
|
||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||
|
||||
// stop and cleanup if disabled
|
||||
// Stop and cleanup if disabled
|
||||
if e.updateManager != nil && disabled {
|
||||
log.Infof("auto-update is disabled, stopping update manager")
|
||||
e.updateManager.Stop()
|
||||
@@ -833,6 +842,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
defer func() {
|
||||
log.Infof("sync finished in %s", time.Since(started))
|
||||
}()
|
||||
|
||||
e.readyChanOnce.Do(func() {
|
||||
if e.readyChan != nil {
|
||||
close(e.readyChan)
|
||||
}
|
||||
})
|
||||
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -879,9 +895,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
uCheckTime := time.Now()
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("update check finished in %s", time.Since(uCheckTime))
|
||||
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
@@ -924,7 +942,9 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
|
||||
e.relayManager.UpdateServerURLs(update.Urls)
|
||||
if !relayClient.IsDisableRelay() {
|
||||
e.relayManager.UpdateServerURLs(update.Urls)
|
||||
}
|
||||
|
||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||
@@ -1538,6 +1558,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1560,10 +1581,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
defer e.shutdownWg.Done()
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
start := time.Now()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
gotLock := time.Since(start)
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
@@ -1587,8 +1606,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
|
||||
|
||||
if msg.Body.Type == sProto.Body_OFFER {
|
||||
conn.OnRemoteOffer(*offerAnswer)
|
||||
} else {
|
||||
@@ -1822,18 +1839,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||
return e.routeManager
|
||||
}
|
||||
|
||||
// GetFirewallManager returns the firewall manager.
|
||||
// GetFirewallManager returns the firewall manager
|
||||
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||
return e.firewall
|
||||
}
|
||||
|
||||
// GetExposeManager returns the expose session manager.
|
||||
func (e *Engine) GetExposeManager() *expose.Manager {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
return e.exposeManager
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const renewTimeout = 10 * time.Second
|
||||
|
||||
// Response holds the response from exposing a service.
|
||||
type Response struct {
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
NamePrefix string
|
||||
Domain string
|
||||
Port uint16
|
||||
Protocol int
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
}
|
||||
|
||||
type ManagementClient interface {
|
||||
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
|
||||
RenewExpose(ctx context.Context, domain string) error
|
||||
StopExpose(ctx context.Context, domain string) error
|
||||
}
|
||||
|
||||
// Manager handles expose session lifecycle via the management client.
|
||||
type Manager struct {
|
||||
mgmClient ManagementClient
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewManager creates a new expose Manager using the given management client.
|
||||
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
|
||||
return &Manager{mgmClient: mgmClient, ctx: ctx}
|
||||
}
|
||||
|
||||
// Expose creates a new expose session via the management server.
|
||||
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
|
||||
log.Infof("exposing service on port %d", req.Port)
|
||||
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("expose session created for %s", resp.Domain)
|
||||
|
||||
return fromClientExposeResponse(resp), nil
|
||||
}
|
||||
|
||||
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
defer m.stop(domain)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("context canceled, stopping keep alive for %s", domain)
|
||||
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
if err := m.renew(ctx, domain); err != nil {
|
||||
log.Errorf("renewing expose session for %s: %v", domain, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// renew extends the TTL of an active expose session.
|
||||
func (m *Manager) renew(ctx context.Context, domain string) error {
|
||||
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
|
||||
defer cancel()
|
||||
return m.mgmClient.RenewExpose(renewCtx, domain)
|
||||
}
|
||||
|
||||
// stop terminates an active expose session.
|
||||
func (m *Manager) stop(domain string) {
|
||||
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
|
||||
defer cancel()
|
||||
err := m.mgmClient.StopExpose(stopCtx, domain)
|
||||
if err != nil {
|
||||
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
func TestManager_Expose_Success(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||
return &mgm.ExposeResponse{
|
||||
ServiceName: "my-service",
|
||||
ServiceURL: "https://my-service.example.com",
|
||||
Domain: "my-service.example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
result, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
|
||||
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
|
||||
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
|
||||
}
|
||||
|
||||
func TestManager_Expose_Error(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||
return nil, errors.New("permission denied")
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
_, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
|
||||
}
|
||||
|
||||
func TestManager_Renew_Success(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
err := m.renew(context.Background(), "my-service.example.com")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestManager_Renew_Timeout(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mock := &mgm.MockClient{
|
||||
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||
return ctx.Err()
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(ctx, mock)
|
||||
err := m.renew(ctx, "my-service.example.com")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewRequest(t *testing.T) {
|
||||
req := &daemonProto.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
|
||||
Pin: "123456",
|
||||
Password: "secret",
|
||||
UserGroups: []string{"group1", "group2"},
|
||||
Domain: "custom.example.com",
|
||||
NamePrefix: "my-prefix",
|
||||
}
|
||||
|
||||
exposeReq := NewRequest(req)
|
||||
|
||||
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
|
||||
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
|
||||
assert.Equal(t, "secret", exposeReq.Password, "password should match")
|
||||
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
|
||||
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
|
||||
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
|
||||
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
|
||||
return &Request{
|
||||
Port: uint16(req.Port),
|
||||
Protocol: int(req.Protocol),
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
Domain: req.Domain,
|
||||
NamePrefix: req.NamePrefix,
|
||||
}
|
||||
}
|
||||
|
||||
func toClientExposeRequest(req Request) mgm.ExposeRequest {
|
||||
return mgm.ExposeRequest{
|
||||
NamePrefix: req.NamePrefix,
|
||||
Domain: req.Domain,
|
||||
Port: req.Port,
|
||||
Protocol: req.Protocol,
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
}
|
||||
}
|
||||
|
||||
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
|
||||
return &Response{
|
||||
ServiceName: response.ServiceName,
|
||||
Domain: response.Domain,
|
||||
ServiceURL: response.ServiceURL,
|
||||
}
|
||||
}
|
||||
@@ -22,56 +22,51 @@ func prepareFd() (int, error) {
|
||||
|
||||
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
for {
|
||||
// Wait until fd is readable or context is cancelled, to avoid a busy-loop
|
||||
// when the routing socket returns EAGAIN (e.g. immediately after wakeup).
|
||||
if err := waitReadable(ctx, fd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) {
|
||||
return fmt.Errorf("routing socket closed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("read routing socket: %w", err)
|
||||
}
|
||||
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,33 +90,3 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
|
||||
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
||||
func waitReadable(ctx context.Context, fd int) error {
|
||||
var fdset unix.FdSet
|
||||
if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) {
|
||||
return fmt.Errorf("fd %d out of range for FdSet", fd)
|
||||
}
|
||||
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fdset = unix.FdSet{}
|
||||
fdset.Set(fd)
|
||||
// Use a 1-second timeout so we can re-check ctx periodically.
|
||||
tv := unix.Timeval{Sec: 1}
|
||||
n, err := unix.Select(fd+1, &fdset, nil, nil, &tv)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINTR) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("select on routing socket: %w", err)
|
||||
}
|
||||
if n > 0 {
|
||||
return nil
|
||||
}
|
||||
// timeout — loop back and re-check ctx
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package peer
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
type ServiceDependencies struct {
|
||||
@@ -32,6 +34,7 @@ type ServiceDependencies struct {
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
RelayManager *relayClient.Manager
|
||||
SrWatcher *guard.SRWatcher
|
||||
Semaphore *semaphoregroup.SemaphoreGroup
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
}
|
||||
|
||||
@@ -108,8 +111,9 @@ type Conn struct {
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
handshaker *Handshaker
|
||||
|
||||
guard *guard.Guard
|
||||
wg sync.WaitGroup
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
wg sync.WaitGroup
|
||||
|
||||
// debug purpose
|
||||
dumpState *stateDump
|
||||
@@ -135,6 +139,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
@@ -149,10 +154,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||
// be used.
|
||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
if err := conn.semaphore.Add(engineCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.opened {
|
||||
conn.semaphore.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -163,6 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
conn.semaphore.Done()
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
@@ -196,6 +207,10 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
conn.wg.Add(1)
|
||||
go func() {
|
||||
defer conn.wg.Done()
|
||||
|
||||
conn.waitInitialRandomSleepTime(conn.ctx)
|
||||
conn.semaphore.Done()
|
||||
|
||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||
}()
|
||||
conn.opened = true
|
||||
@@ -655,6 +670,19 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
||||
maxWait := 300
|
||||
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
|
||||
|
||||
timeout := time.NewTimer(duration)
|
||||
defer timeout.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-timeout.C:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) isRelayed() bool {
|
||||
switch conn.currentConnPriority {
|
||||
case conntype.Relay, conntype.ICETurn:
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
var testDispatcher = dispatcher.NewConnectionDispatcher()
|
||||
@@ -52,6 +53,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
|
||||
sd := ServiceDependencies{
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
@@ -69,6 +71,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
@@ -107,6 +110,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -125,6 +126,10 @@ func GenerateICECredentials() (string, string, error) {
|
||||
}
|
||||
|
||||
func CandidateTypes() []ice.CandidateType {
|
||||
if relayClient.IsDisableRelay() {
|
||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
|
||||
}
|
||||
|
||||
if hasICEForceRelayConn() {
|
||||
return []ice.CandidateType{ice.CandidateTypeRelay}
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
|
||||
|
||||
configDir := filepath.Join(DefaultConfigPathDir, username)
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(configDir, 0700); err != nil {
|
||||
if err := os.MkdirAll(configDir, 0600); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
@@ -206,15 +206,9 @@ func getConfigDirForUser(username string) (string, error) {
|
||||
return configDir, nil
|
||||
}
|
||||
|
||||
func fileExists(path string) (bool, error) {
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
@@ -641,11 +635,7 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
|
||||
|
||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
|
||||
}
|
||||
|
||||
@@ -654,11 +644,7 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
|
||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
@@ -671,7 +657,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
err = util.EnforcePermission(input.ConfigPath)
|
||||
err := util.EnforcePermission(input.ConfigPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
@@ -798,12 +784,7 @@ func ReadConfig(configPath string) (*Config, error) {
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||
configExists, err := fileExists(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
|
||||
if configExists {
|
||||
if fileExists(configPath) {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
@@ -850,11 +831,7 @@ func DirectWriteOutConfig(path string, config *Config) error {
|
||||
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
|
||||
@@ -256,11 +256,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if profileExists {
|
||||
if fileExists(profPath) {
|
||||
return ErrProfileAlreadyExists
|
||||
}
|
||||
|
||||
@@ -289,11 +285,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if !profileExists {
|
||||
if !fileExists(profPath) {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
|
||||
|
||||
@@ -20,11 +20,7 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, profileName+".state.json")
|
||||
stateFileExists, err := fileExists(stateFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
|
||||
}
|
||||
if !stateFileExists {
|
||||
if !fileExists(stateFile) {
|
||||
return nil, errors.New("profile state file does not exist")
|
||||
}
|
||||
|
||||
|
||||
@@ -263,14 +263,8 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
|
||||
case <-closer:
|
||||
return
|
||||
case routerStates := <-subscription.Events():
|
||||
select {
|
||||
case peerStateUpdate <- routerStates:
|
||||
log.Debugf("triggered route state update for Peer: %s", peerKey)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-closer:
|
||||
return
|
||||
}
|
||||
peerStateUpdate <- routerStates
|
||||
log.Debugf("triggered route state update for Peer: %s", peerKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -351,11 +351,6 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
|
||||
logger.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
|
||||
// Allow time for route changes to be applied before sending
|
||||
// the DNS response (relevant on iOS where setTunnelNetworkSettings
|
||||
// is asynchronous).
|
||||
waitForRouteSettlement(logger)
|
||||
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package dnsinterceptor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const routeSettleDelay = 500 * time.Millisecond
|
||||
|
||||
// waitForRouteSettlement introduces a short delay on iOS to allow
|
||||
// setTunnelNetworkSettings to apply route changes before the DNS
|
||||
// response reaches the application. Without this, the first request
|
||||
// to a newly resolved domain may bypass the tunnel.
|
||||
func waitForRouteSettlement(logger *log.Entry) {
|
||||
logger.Tracef("waiting %v for iOS route settlement", routeSettleDelay)
|
||||
time.Sleep(routeSettleDelay)
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package dnsinterceptor
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
|
||||
func waitForRouteSettlement(_ *log.Entry) {
|
||||
// No-op on non-iOS platforms: route changes are applied synchronously by
|
||||
// the kernel, so no settlement delay is needed before the DNS response
|
||||
// reaches the application. The delay is only required on iOS where
|
||||
// setTunnelNetworkSettings applies routes asynchronously.
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type Agent interface {
|
||||
Up(ctx context.Context) error
|
||||
Down(ctx context.Context) error
|
||||
Status() (internal.StatusType, error)
|
||||
}
|
||||
|
||||
type SleepHandler struct {
|
||||
agent Agent
|
||||
|
||||
mu sync.Mutex
|
||||
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
|
||||
sleepTriggeredDown bool
|
||||
}
|
||||
|
||||
func New(agent Agent) *SleepHandler {
|
||||
return &SleepHandler{
|
||||
agent: agent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.sleepTriggeredDown {
|
||||
log.Info("skipping up because wasn't sleep down")
|
||||
return nil
|
||||
}
|
||||
|
||||
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||
s.sleepTriggeredDown = false
|
||||
|
||||
log.Info("running up after wake up")
|
||||
err := s.agent.Up(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("running up failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("running up command executed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
status, err := s.agent.Status()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||
log.Infof("skipping setting the agent down because status is %s", status)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("running down after system started sleeping")
|
||||
|
||||
if err = s.agent.Down(ctx); err != nil {
|
||||
log.Errorf("running down failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.sleepTriggeredDown = true
|
||||
|
||||
log.Info("running down executed successfully")
|
||||
return nil
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockAgent struct {
|
||||
upErr error
|
||||
downErr error
|
||||
statusErr error
|
||||
status internal.StatusType
|
||||
upCalls int
|
||||
}
|
||||
|
||||
func (m *mockAgent) Up(_ context.Context) error {
|
||||
m.upCalls++
|
||||
return m.upErr
|
||||
}
|
||||
|
||||
func (m *mockAgent) Down(_ context.Context) error {
|
||||
return m.downErr
|
||||
}
|
||||
|
||||
func (m *mockAgent) Status() (internal.StatusType, error) {
|
||||
return m.status, m.statusErr
|
||||
}
|
||||
|
||||
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
|
||||
agent := &mockAgent{status: status}
|
||||
return New(agent), agent
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||
h, _ := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
// Even if Up fails, flag should be reset
|
||||
_ = h.HandleWakeUp(context.Background())
|
||||
|
||||
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, agent.upCalls)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
agent.upErr = errors.New("up failed")
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.upErr)
|
||||
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
_ = h.HandleWakeUp(context.Background())
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
|
||||
}
|
||||
|
||||
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Idle", internal.StatusIdle},
|
||||
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||
{"LoginFailed", internal.StatusLoginFailed},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, _ := newHandler(tt.status)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Connecting", internal.StatusConnecting},
|
||||
{"Connected", internal.StatusConnected},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, _ := newHandler(tt.status)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, h.sleepTriggeredDown)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
|
||||
agent := &mockAgent{statusErr: errors.New("status error")}
|
||||
h := New(agent)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.statusErr)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
}
|
||||
|
||||
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
|
||||
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
|
||||
h := New(agent)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.downErr)
|
||||
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.6
|
||||
// protoc v6.33.3
|
||||
// protoc v6.32.1
|
||||
// source: daemon.proto
|
||||
|
||||
package proto
|
||||
@@ -88,58 +88,6 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
type ExposeProtocol int32
|
||||
|
||||
const (
|
||||
ExposeProtocol_EXPOSE_HTTP ExposeProtocol = 0
|
||||
ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1
|
||||
ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2
|
||||
ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3
|
||||
)
|
||||
|
||||
// Enum value maps for ExposeProtocol.
|
||||
var (
|
||||
ExposeProtocol_name = map[int32]string{
|
||||
0: "EXPOSE_HTTP",
|
||||
1: "EXPOSE_HTTPS",
|
||||
2: "EXPOSE_TCP",
|
||||
3: "EXPOSE_UDP",
|
||||
}
|
||||
ExposeProtocol_value = map[string]int32{
|
||||
"EXPOSE_HTTP": 0,
|
||||
"EXPOSE_HTTPS": 1,
|
||||
"EXPOSE_TCP": 2,
|
||||
"EXPOSE_UDP": 3,
|
||||
}
|
||||
)
|
||||
|
||||
func (x ExposeProtocol) Enum() *ExposeProtocol {
|
||||
p := new(ExposeProtocol)
|
||||
*p = x
|
||||
return p
|
||||
}
|
||||
|
||||
func (x ExposeProtocol) String() string {
|
||||
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
|
||||
}
|
||||
|
||||
func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor {
|
||||
return file_daemon_proto_enumTypes[1].Descriptor()
|
||||
}
|
||||
|
||||
func (ExposeProtocol) Type() protoreflect.EnumType {
|
||||
return &file_daemon_proto_enumTypes[1]
|
||||
}
|
||||
|
||||
func (x ExposeProtocol) Number() protoreflect.EnumNumber {
|
||||
return protoreflect.EnumNumber(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ExposeProtocol.Descriptor instead.
|
||||
func (ExposeProtocol) EnumDescriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
// avoid collision with loglevel enum
|
||||
type OSLifecycleRequest_CycleType int32
|
||||
|
||||
@@ -174,11 +122,11 @@ func (x OSLifecycleRequest_CycleType) String() string {
|
||||
}
|
||||
|
||||
func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor {
|
||||
return file_daemon_proto_enumTypes[2].Descriptor()
|
||||
return file_daemon_proto_enumTypes[1].Descriptor()
|
||||
}
|
||||
|
||||
func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType {
|
||||
return &file_daemon_proto_enumTypes[2]
|
||||
return &file_daemon_proto_enumTypes[1]
|
||||
}
|
||||
|
||||
func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber {
|
||||
@@ -226,11 +174,11 @@ func (x SystemEvent_Severity) String() string {
|
||||
}
|
||||
|
||||
func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor {
|
||||
return file_daemon_proto_enumTypes[3].Descriptor()
|
||||
return file_daemon_proto_enumTypes[2].Descriptor()
|
||||
}
|
||||
|
||||
func (SystemEvent_Severity) Type() protoreflect.EnumType {
|
||||
return &file_daemon_proto_enumTypes[3]
|
||||
return &file_daemon_proto_enumTypes[2]
|
||||
}
|
||||
|
||||
func (x SystemEvent_Severity) Number() protoreflect.EnumNumber {
|
||||
@@ -281,11 +229,11 @@ func (x SystemEvent_Category) String() string {
|
||||
}
|
||||
|
||||
func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor {
|
||||
return file_daemon_proto_enumTypes[4].Descriptor()
|
||||
return file_daemon_proto_enumTypes[3].Descriptor()
|
||||
}
|
||||
|
||||
func (SystemEvent_Category) Type() protoreflect.EnumType {
|
||||
return &file_daemon_proto_enumTypes[4]
|
||||
return &file_daemon_proto_enumTypes[3]
|
||||
}
|
||||
|
||||
func (x SystemEvent_Category) Number() protoreflect.EnumNumber {
|
||||
@@ -5652,224 +5600,6 @@ func (x *InstallerResultResponse) GetErrorMsg() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type ExposeServiceRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"`
|
||||
Protocol ExposeProtocol `protobuf:"varint,2,opt,name=protocol,proto3,enum=daemon.ExposeProtocol" json:"protocol,omitempty"`
|
||||
Pin string `protobuf:"bytes,3,opt,name=pin,proto3" json:"pin,omitempty"`
|
||||
Password string `protobuf:"bytes,4,opt,name=password,proto3" json:"password,omitempty"`
|
||||
UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"`
|
||||
Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"`
|
||||
NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) Reset() {
|
||||
*x = ExposeServiceRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ExposeServiceRequest) ProtoMessage() {}
|
||||
|
||||
func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead.
|
||||
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{85}
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) GetPort() uint32 {
|
||||
if x != nil {
|
||||
return x.Port
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) GetProtocol() ExposeProtocol {
|
||||
if x != nil {
|
||||
return x.Protocol
|
||||
}
|
||||
return ExposeProtocol_EXPOSE_HTTP
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) GetPin() string {
|
||||
if x != nil {
|
||||
return x.Pin
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) GetPassword() string {
|
||||
if x != nil {
|
||||
return x.Password
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) GetUserGroups() []string {
|
||||
if x != nil {
|
||||
return x.UserGroups
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) GetDomain() string {
|
||||
if x != nil {
|
||||
return x.Domain
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) GetNamePrefix() string {
|
||||
if x != nil {
|
||||
return x.NamePrefix
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ExposeServiceEvent struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Types that are valid to be assigned to Event:
|
||||
//
|
||||
// *ExposeServiceEvent_Ready
|
||||
Event isExposeServiceEvent_Event `protobuf_oneof:"event"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ExposeServiceEvent) Reset() {
|
||||
*x = ExposeServiceEvent{}
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ExposeServiceEvent) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ExposeServiceEvent) ProtoMessage() {}
|
||||
|
||||
func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead.
|
||||
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{86}
|
||||
}
|
||||
|
||||
func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event {
|
||||
if x != nil {
|
||||
return x.Event
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *ExposeServiceEvent) GetReady() *ExposeServiceReady {
|
||||
if x != nil {
|
||||
if x, ok := x.Event.(*ExposeServiceEvent_Ready); ok {
|
||||
return x.Ready
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type isExposeServiceEvent_Event interface {
|
||||
isExposeServiceEvent_Event()
|
||||
}
|
||||
|
||||
type ExposeServiceEvent_Ready struct {
|
||||
Ready *ExposeServiceReady `protobuf:"bytes,1,opt,name=ready,proto3,oneof"`
|
||||
}
|
||||
|
||||
func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {}
|
||||
|
||||
type ExposeServiceReady struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"`
|
||||
ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"`
|
||||
Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ExposeServiceReady) Reset() {
|
||||
*x = ExposeServiceReady{}
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ExposeServiceReady) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ExposeServiceReady) ProtoMessage() {}
|
||||
|
||||
func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead.
|
||||
func (*ExposeServiceReady) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{87}
|
||||
}
|
||||
|
||||
func (x *ExposeServiceReady) GetServiceName() string {
|
||||
if x != nil {
|
||||
return x.ServiceName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ExposeServiceReady) GetServiceUrl() string {
|
||||
if x != nil {
|
||||
return x.ServiceUrl
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ExposeServiceReady) GetDomain() string {
|
||||
if x != nil {
|
||||
return x.Domain
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type PortInfo_Range struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
||||
@@ -5880,7 +5610,7 @@ type PortInfo_Range struct {
|
||||
|
||||
func (x *PortInfo_Range) Reset() {
|
||||
*x = PortInfo_Range{}
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5892,7 +5622,7 @@ func (x *PortInfo_Range) String() string {
|
||||
func (*PortInfo_Range) ProtoMessage() {}
|
||||
|
||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6419,25 +6149,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x16InstallerResultRequest\"O\n" +
|
||||
"\x17InstallerResultResponse\x12\x18\n" +
|
||||
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
||||
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\xe6\x01\n" +
|
||||
"\x14ExposeServiceRequest\x12\x12\n" +
|
||||
"\x04port\x18\x01 \x01(\rR\x04port\x122\n" +
|
||||
"\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" +
|
||||
"\x03pin\x18\x03 \x01(\tR\x03pin\x12\x1a\n" +
|
||||
"\bpassword\x18\x04 \x01(\tR\bpassword\x12\x1f\n" +
|
||||
"\vuser_groups\x18\x05 \x03(\tR\n" +
|
||||
"userGroups\x12\x16\n" +
|
||||
"\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" +
|
||||
"\vname_prefix\x18\a \x01(\tR\n" +
|
||||
"namePrefix\"Q\n" +
|
||||
"\x12ExposeServiceEvent\x122\n" +
|
||||
"\x05ready\x18\x01 \x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" +
|
||||
"\x05event\"p\n" +
|
||||
"\x12ExposeServiceReady\x12!\n" +
|
||||
"\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" +
|
||||
"\vservice_url\x18\x02 \x01(\tR\n" +
|
||||
"serviceUrl\x12\x16\n" +
|
||||
"\x06domain\x18\x03 \x01(\tR\x06domain*b\n" +
|
||||
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
|
||||
"\bLogLevel\x12\v\n" +
|
||||
"\aUNKNOWN\x10\x00\x12\t\n" +
|
||||
"\x05PANIC\x10\x01\x12\t\n" +
|
||||
@@ -6446,14 +6158,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x04WARN\x10\x04\x12\b\n" +
|
||||
"\x04INFO\x10\x05\x12\t\n" +
|
||||
"\x05DEBUG\x10\x06\x12\t\n" +
|
||||
"\x05TRACE\x10\a*S\n" +
|
||||
"\x0eExposeProtocol\x12\x0f\n" +
|
||||
"\vEXPOSE_HTTP\x10\x00\x12\x10\n" +
|
||||
"\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" +
|
||||
"\n" +
|
||||
"EXPOSE_TCP\x10\x02\x12\x0e\n" +
|
||||
"\n" +
|
||||
"EXPOSE_UDP\x10\x032\xac\x15\n" +
|
||||
"\x05TRACE\x10\a2\xdd\x14\n" +
|
||||
"\rDaemonService\x126\n" +
|
||||
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
||||
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
||||
@@ -6492,8 +6197,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
|
||||
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
|
||||
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
|
||||
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00\x12M\n" +
|
||||
"\rExposeService\x12\x1c.daemon.ExposeServiceRequest\x1a\x1a.daemon.ExposeServiceEvent\"\x000\x01B\bZ\x06/protob\x06proto3"
|
||||
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
|
||||
|
||||
var (
|
||||
file_daemon_proto_rawDescOnce sync.Once
|
||||
@@ -6507,222 +6211,214 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
return file_daemon_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91)
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88)
|
||||
var file_daemon_proto_goTypes = []any{
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(ExposeProtocol)(0), // 1: daemon.ExposeProtocol
|
||||
(OSLifecycleRequest_CycleType)(0), // 2: daemon.OSLifecycleRequest.CycleType
|
||||
(SystemEvent_Severity)(0), // 3: daemon.SystemEvent.Severity
|
||||
(SystemEvent_Category)(0), // 4: daemon.SystemEvent.Category
|
||||
(*EmptyRequest)(nil), // 5: daemon.EmptyRequest
|
||||
(*OSLifecycleRequest)(nil), // 6: daemon.OSLifecycleRequest
|
||||
(*OSLifecycleResponse)(nil), // 7: daemon.OSLifecycleResponse
|
||||
(*LoginRequest)(nil), // 8: daemon.LoginRequest
|
||||
(*LoginResponse)(nil), // 9: daemon.LoginResponse
|
||||
(*WaitSSOLoginRequest)(nil), // 10: daemon.WaitSSOLoginRequest
|
||||
(*WaitSSOLoginResponse)(nil), // 11: daemon.WaitSSOLoginResponse
|
||||
(*UpRequest)(nil), // 12: daemon.UpRequest
|
||||
(*UpResponse)(nil), // 13: daemon.UpResponse
|
||||
(*StatusRequest)(nil), // 14: daemon.StatusRequest
|
||||
(*StatusResponse)(nil), // 15: daemon.StatusResponse
|
||||
(*DownRequest)(nil), // 16: daemon.DownRequest
|
||||
(*DownResponse)(nil), // 17: daemon.DownResponse
|
||||
(*GetConfigRequest)(nil), // 18: daemon.GetConfigRequest
|
||||
(*GetConfigResponse)(nil), // 19: daemon.GetConfigResponse
|
||||
(*PeerState)(nil), // 20: daemon.PeerState
|
||||
(*LocalPeerState)(nil), // 21: daemon.LocalPeerState
|
||||
(*SignalState)(nil), // 22: daemon.SignalState
|
||||
(*ManagementState)(nil), // 23: daemon.ManagementState
|
||||
(*RelayState)(nil), // 24: daemon.RelayState
|
||||
(*NSGroupState)(nil), // 25: daemon.NSGroupState
|
||||
(*SSHSessionInfo)(nil), // 26: daemon.SSHSessionInfo
|
||||
(*SSHServerState)(nil), // 27: daemon.SSHServerState
|
||||
(*FullStatus)(nil), // 28: daemon.FullStatus
|
||||
(*ListNetworksRequest)(nil), // 29: daemon.ListNetworksRequest
|
||||
(*ListNetworksResponse)(nil), // 30: daemon.ListNetworksResponse
|
||||
(*SelectNetworksRequest)(nil), // 31: daemon.SelectNetworksRequest
|
||||
(*SelectNetworksResponse)(nil), // 32: daemon.SelectNetworksResponse
|
||||
(*IPList)(nil), // 33: daemon.IPList
|
||||
(*Network)(nil), // 34: daemon.Network
|
||||
(*PortInfo)(nil), // 35: daemon.PortInfo
|
||||
(*ForwardingRule)(nil), // 36: daemon.ForwardingRule
|
||||
(*ForwardingRulesResponse)(nil), // 37: daemon.ForwardingRulesResponse
|
||||
(*DebugBundleRequest)(nil), // 38: daemon.DebugBundleRequest
|
||||
(*DebugBundleResponse)(nil), // 39: daemon.DebugBundleResponse
|
||||
(*GetLogLevelRequest)(nil), // 40: daemon.GetLogLevelRequest
|
||||
(*GetLogLevelResponse)(nil), // 41: daemon.GetLogLevelResponse
|
||||
(*SetLogLevelRequest)(nil), // 42: daemon.SetLogLevelRequest
|
||||
(*SetLogLevelResponse)(nil), // 43: daemon.SetLogLevelResponse
|
||||
(*State)(nil), // 44: daemon.State
|
||||
(*ListStatesRequest)(nil), // 45: daemon.ListStatesRequest
|
||||
(*ListStatesResponse)(nil), // 46: daemon.ListStatesResponse
|
||||
(*CleanStateRequest)(nil), // 47: daemon.CleanStateRequest
|
||||
(*CleanStateResponse)(nil), // 48: daemon.CleanStateResponse
|
||||
(*DeleteStateRequest)(nil), // 49: daemon.DeleteStateRequest
|
||||
(*DeleteStateResponse)(nil), // 50: daemon.DeleteStateResponse
|
||||
(*SetSyncResponsePersistenceRequest)(nil), // 51: daemon.SetSyncResponsePersistenceRequest
|
||||
(*SetSyncResponsePersistenceResponse)(nil), // 52: daemon.SetSyncResponsePersistenceResponse
|
||||
(*TCPFlags)(nil), // 53: daemon.TCPFlags
|
||||
(*TracePacketRequest)(nil), // 54: daemon.TracePacketRequest
|
||||
(*TraceStage)(nil), // 55: daemon.TraceStage
|
||||
(*TracePacketResponse)(nil), // 56: daemon.TracePacketResponse
|
||||
(*SubscribeRequest)(nil), // 57: daemon.SubscribeRequest
|
||||
(*SystemEvent)(nil), // 58: daemon.SystemEvent
|
||||
(*GetEventsRequest)(nil), // 59: daemon.GetEventsRequest
|
||||
(*GetEventsResponse)(nil), // 60: daemon.GetEventsResponse
|
||||
(*SwitchProfileRequest)(nil), // 61: daemon.SwitchProfileRequest
|
||||
(*SwitchProfileResponse)(nil), // 62: daemon.SwitchProfileResponse
|
||||
(*SetConfigRequest)(nil), // 63: daemon.SetConfigRequest
|
||||
(*SetConfigResponse)(nil), // 64: daemon.SetConfigResponse
|
||||
(*AddProfileRequest)(nil), // 65: daemon.AddProfileRequest
|
||||
(*AddProfileResponse)(nil), // 66: daemon.AddProfileResponse
|
||||
(*RemoveProfileRequest)(nil), // 67: daemon.RemoveProfileRequest
|
||||
(*RemoveProfileResponse)(nil), // 68: daemon.RemoveProfileResponse
|
||||
(*ListProfilesRequest)(nil), // 69: daemon.ListProfilesRequest
|
||||
(*ListProfilesResponse)(nil), // 70: daemon.ListProfilesResponse
|
||||
(*Profile)(nil), // 71: daemon.Profile
|
||||
(*GetActiveProfileRequest)(nil), // 72: daemon.GetActiveProfileRequest
|
||||
(*GetActiveProfileResponse)(nil), // 73: daemon.GetActiveProfileResponse
|
||||
(*LogoutRequest)(nil), // 74: daemon.LogoutRequest
|
||||
(*LogoutResponse)(nil), // 75: daemon.LogoutResponse
|
||||
(*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest
|
||||
(*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse
|
||||
(*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest
|
||||
(*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse
|
||||
(*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest
|
||||
(*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse
|
||||
(*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest
|
||||
(*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse
|
||||
(*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest
|
||||
(*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse
|
||||
(*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest
|
||||
(*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse
|
||||
(*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest
|
||||
(*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse
|
||||
(*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest
|
||||
(*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent
|
||||
(*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady
|
||||
nil, // 93: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 94: daemon.PortInfo.Range
|
||||
nil, // 95: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 96: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 97: google.protobuf.Timestamp
|
||||
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
|
||||
(SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity
|
||||
(SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category
|
||||
(*EmptyRequest)(nil), // 4: daemon.EmptyRequest
|
||||
(*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest
|
||||
(*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse
|
||||
(*LoginRequest)(nil), // 7: daemon.LoginRequest
|
||||
(*LoginResponse)(nil), // 8: daemon.LoginResponse
|
||||
(*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest
|
||||
(*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse
|
||||
(*UpRequest)(nil), // 11: daemon.UpRequest
|
||||
(*UpResponse)(nil), // 12: daemon.UpResponse
|
||||
(*StatusRequest)(nil), // 13: daemon.StatusRequest
|
||||
(*StatusResponse)(nil), // 14: daemon.StatusResponse
|
||||
(*DownRequest)(nil), // 15: daemon.DownRequest
|
||||
(*DownResponse)(nil), // 16: daemon.DownResponse
|
||||
(*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest
|
||||
(*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse
|
||||
(*PeerState)(nil), // 19: daemon.PeerState
|
||||
(*LocalPeerState)(nil), // 20: daemon.LocalPeerState
|
||||
(*SignalState)(nil), // 21: daemon.SignalState
|
||||
(*ManagementState)(nil), // 22: daemon.ManagementState
|
||||
(*RelayState)(nil), // 23: daemon.RelayState
|
||||
(*NSGroupState)(nil), // 24: daemon.NSGroupState
|
||||
(*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo
|
||||
(*SSHServerState)(nil), // 26: daemon.SSHServerState
|
||||
(*FullStatus)(nil), // 27: daemon.FullStatus
|
||||
(*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest
|
||||
(*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse
|
||||
(*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest
|
||||
(*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse
|
||||
(*IPList)(nil), // 32: daemon.IPList
|
||||
(*Network)(nil), // 33: daemon.Network
|
||||
(*PortInfo)(nil), // 34: daemon.PortInfo
|
||||
(*ForwardingRule)(nil), // 35: daemon.ForwardingRule
|
||||
(*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse
|
||||
(*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest
|
||||
(*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse
|
||||
(*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest
|
||||
(*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse
|
||||
(*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest
|
||||
(*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse
|
||||
(*State)(nil), // 43: daemon.State
|
||||
(*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest
|
||||
(*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse
|
||||
(*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest
|
||||
(*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse
|
||||
(*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest
|
||||
(*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse
|
||||
(*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest
|
||||
(*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse
|
||||
(*TCPFlags)(nil), // 52: daemon.TCPFlags
|
||||
(*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest
|
||||
(*TraceStage)(nil), // 54: daemon.TraceStage
|
||||
(*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse
|
||||
(*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest
|
||||
(*SystemEvent)(nil), // 57: daemon.SystemEvent
|
||||
(*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest
|
||||
(*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse
|
||||
(*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest
|
||||
(*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse
|
||||
(*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest
|
||||
(*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse
|
||||
(*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest
|
||||
(*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse
|
||||
(*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest
|
||||
(*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse
|
||||
(*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest
|
||||
(*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse
|
||||
(*Profile)(nil), // 70: daemon.Profile
|
||||
(*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest
|
||||
(*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse
|
||||
(*LogoutRequest)(nil), // 73: daemon.LogoutRequest
|
||||
(*LogoutResponse)(nil), // 74: daemon.LogoutResponse
|
||||
(*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest
|
||||
(*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse
|
||||
(*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest
|
||||
(*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse
|
||||
(*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest
|
||||
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
|
||||
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
|
||||
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
|
||||
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
|
||||
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
|
||||
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
|
||||
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
|
||||
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
|
||||
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
|
||||
nil, // 89: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range
|
||||
nil, // 91: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
||||
96, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
97, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
97, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
96, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||
23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
||||
24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||
34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
93, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
94, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
||||
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
||||
23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
||||
0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||
44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
|
||||
53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
|
||||
55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
97, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
95, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
96, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
|
||||
92, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||
33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
||||
31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
|
||||
38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||
42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
||||
47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
||||
49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||
51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
||||
54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||
57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
||||
59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
||||
61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
||||
63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
|
||||
65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
|
||||
67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
|
||||
69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
|
||||
72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||
74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||
76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||
78, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
80, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
82, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
84, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
86, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
6, // 69: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
||||
88, // 70: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
90, // 71: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||
9, // 72: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
11, // 73: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
13, // 74: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
15, // 75: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
17, // 76: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
19, // 77: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
30, // 78: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
32, // 79: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
32, // 80: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
37, // 81: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
39, // 82: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
41, // 83: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
43, // 84: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
46, // 85: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
48, // 86: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
50, // 87: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
52, // 88: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
56, // 89: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
58, // 90: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
60, // 91: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
62, // 92: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
64, // 93: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
66, // 94: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
68, // 95: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
70, // 96: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
73, // 97: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
75, // 98: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
77, // 99: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
79, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
81, // 101: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
83, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
85, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
87, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
7, // 105: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
||||
89, // 106: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
91, // 107: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||
72, // [72:108] is the sub-list for method output_type
|
||||
36, // [36:72] is the sub-list for method input_type
|
||||
36, // [36:36] is the sub-list for extension type_name
|
||||
36, // [36:36] is the sub-list for extension extendee
|
||||
0, // [0:36] is the sub-list for field type_name
|
||||
43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
|
||||
52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
|
||||
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
||||
30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
|
||||
37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||
41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
||||
46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
||||
48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||
50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
||||
53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||
56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
||||
58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
||||
60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
||||
62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
|
||||
64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
|
||||
66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
|
||||
68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
|
||||
71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||
73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||
75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
||||
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
||||
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
69, // [69:104] is the sub-list for method output_type
|
||||
34, // [34:69] is the sub-list for method input_type
|
||||
34, // [34:34] is the sub-list for extension type_name
|
||||
34, // [34:34] is the sub-list for extension extendee
|
||||
0, // [0:34] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_daemon_proto_init() }
|
||||
@@ -6743,16 +6439,13 @@ func file_daemon_proto_init() {
|
||||
file_daemon_proto_msgTypes[58].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[69].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[86].OneofWrappers = []any{
|
||||
(*ExposeServiceEvent_Ready)(nil),
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||
NumEnums: 5,
|
||||
NumMessages: 91,
|
||||
NumEnums: 4,
|
||||
NumMessages: 88,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -103,9 +103,6 @@ service DaemonService {
|
||||
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||
|
||||
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
||||
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -804,32 +801,3 @@ message InstallerResultResponse {
|
||||
bool success = 1;
|
||||
string errorMsg = 2;
|
||||
}
|
||||
|
||||
enum ExposeProtocol {
|
||||
EXPOSE_HTTP = 0;
|
||||
EXPOSE_HTTPS = 1;
|
||||
EXPOSE_TCP = 2;
|
||||
EXPOSE_UDP = 3;
|
||||
}
|
||||
|
||||
message ExposeServiceRequest {
|
||||
uint32 port = 1;
|
||||
ExposeProtocol protocol = 2;
|
||||
string pin = 3;
|
||||
string password = 4;
|
||||
repeated string user_groups = 5;
|
||||
string domain = 6;
|
||||
string name_prefix = 7;
|
||||
}
|
||||
|
||||
message ExposeServiceEvent {
|
||||
oneof event {
|
||||
ExposeServiceReady ready = 1;
|
||||
}
|
||||
}
|
||||
|
||||
message ExposeServiceReady {
|
||||
string service_name = 1;
|
||||
string service_url = 2;
|
||||
string domain = 3;
|
||||
}
|
||||
|
||||
@@ -76,8 +76,6 @@ type DaemonServiceClient interface {
|
||||
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
|
||||
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
||||
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -426,38 +424,6 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &daemonServiceExposeServiceClient{stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type DaemonService_ExposeServiceClient interface {
|
||||
Recv() (*ExposeServiceEvent, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type daemonServiceExposeServiceClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *daemonServiceExposeServiceClient) Recv() (*ExposeServiceEvent, error) {
|
||||
m := new(ExposeServiceEvent)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
@@ -520,8 +486,6 @@ type DaemonServiceServer interface {
|
||||
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
|
||||
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
||||
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -634,9 +598,6 @@ func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLi
|
||||
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method ExposeService not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -1283,27 +1244,6 @@ func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Cont
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(ExposeServiceRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(DaemonServiceServer).ExposeService(m, &daemonServiceExposeServiceServer{stream})
|
||||
}
|
||||
|
||||
type DaemonService_ExposeServiceServer interface {
|
||||
Send(*ExposeServiceEvent) error
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type daemonServiceExposeServiceServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *daemonServiceExposeServiceServer) Send(m *ExposeServiceEvent) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -1454,11 +1394,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
Handler: _DaemonService_SubscribeEvents_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
{
|
||||
StreamName: "ExposeService",
|
||||
Handler: _DaemonService_ExposeService_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "daemon.proto",
|
||||
}
|
||||
|
||||
77
client/server/lifecycle.go
Normal file
77
client/server/lifecycle.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||
switch req.GetType() {
|
||||
case proto.OSLifecycleRequest_WAKEUP:
|
||||
return s.handleWakeUp(callerCtx)
|
||||
case proto.OSLifecycleRequest_SLEEP:
|
||||
return s.handleSleep(callerCtx)
|
||||
default:
|
||||
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||
}
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
|
||||
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
|
||||
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
||||
if !s.sleepTriggeredDown.Load() {
|
||||
log.Info("skipping up because wasn't sleep down")
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||
s.sleepTriggeredDown.Store(false)
|
||||
|
||||
log.Info("running up after wake up")
|
||||
_, err := s.Up(callerCtx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("running up failed: %v", err)
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
|
||||
log.Info("running up command executed successfully")
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
|
||||
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
||||
s.mutex.Lock()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
|
||||
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||
log.Infof("skipping setting the agent down because status is %s", status)
|
||||
s.mutex.Unlock()
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
log.Info("running down after system started sleeping")
|
||||
|
||||
_, err = s.Down(callerCtx, &proto.DownRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("running down failed: %v", err)
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
log.Info("running down executed successfully")
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
219
client/server/lifecycle_test.go
Normal file
219
client/server/lifecycle_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
func newTestServer() *Server {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
return &Server{
|
||||
rootCtx: ctx,
|
||||
statusRecorder: peer.NewRecorder(""),
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// sleepTriggeredDown is false by default
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusNeedsLogin)
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.actCancel = cancel
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusConnected)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.actCancel = cancel
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// Manually set the flag to simulate prior sleep down
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
|
||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// First wakeup without prior sleep - should be no-op
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
|
||||
// Simulate prior sleep
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
// First wakeup after sleep - should reset flag
|
||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
|
||||
// Second wakeup - should be no-op
|
||||
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
resp, err := s.handleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||
s := newTestServer()
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
// Even if Up fails, flag should be reset
|
||||
_, _ = s.handleWakeUp(context.Background())
|
||||
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
|
||||
}
|
||||
|
||||
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Idle", internal.StatusIdle},
|
||||
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||
{"LoginFailed", internal.StatusLoginFailed},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := newTestServer()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(tt.status)
|
||||
|
||||
resp, err := s.handleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Connecting", internal.StatusConnecting},
|
||||
{"Connected", internal.StatusConnected},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := newTestServer()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(tt.status)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.actCancel = cancel
|
||||
|
||||
resp, err := s.handleSleep(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.True(t, s.sleepTriggeredDown.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -21,9 +21,7 @@ import (
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -87,7 +85,8 @@ type Server struct {
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
|
||||
sleepTriggeredDown atomic.Bool
|
||||
|
||||
jwtCache *jwtCache
|
||||
}
|
||||
@@ -101,7 +100,7 @@ type oauthAuthFlow struct {
|
||||
|
||||
// New server instance constructor.
|
||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
|
||||
s := &Server{
|
||||
return &Server{
|
||||
rootCtx: ctx,
|
||||
logFile: logFile,
|
||||
persistSyncResponse: true,
|
||||
@@ -111,10 +110,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
@@ -641,6 +636,8 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
|
||||
log.Warnf(errRestoreResidualState, err)
|
||||
}
|
||||
@@ -652,12 +649,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
// not in the progress or already successfully established connection.
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if status != internal.StatusIdle {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("up already in progress: current status %s", status)
|
||||
}
|
||||
|
||||
@@ -674,20 +669,17 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.actCancel = cancel
|
||||
|
||||
if s.config == nil {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("config is not defined, please call login command first")
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to get active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
if msg != nil && msg.ProfileName != nil {
|
||||
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
@@ -695,7 +687,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
|
||||
activeProf, err = s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to get active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
@@ -704,7 +695,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to get active profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile config: %w", err)
|
||||
}
|
||||
@@ -723,7 +713,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
}
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
|
||||
|
||||
s.mutex.Unlock()
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
|
||||
@@ -1323,60 +1312,6 @@ func (s *Server) WaitJWTToken(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy.
|
||||
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
|
||||
s.mutex.Lock()
|
||||
if !s.clientRunning {
|
||||
s.mutex.Unlock()
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
|
||||
}
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||
}
|
||||
|
||||
mgr := engine.GetExposeManager()
|
||||
if mgr == nil {
|
||||
return gstatus.Errorf(codes.Internal, "expose manager not available")
|
||||
}
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
exposeCtx, exposeCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer exposeCancel()
|
||||
|
||||
mgmReq := expose.NewRequest(req)
|
||||
result, err := mgr.Expose(exposeCtx, *mgmReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := srv.Send(&proto.ExposeServiceEvent{
|
||||
Event: &proto.ExposeServiceEvent_Ready{
|
||||
Ready: &proto.ExposeServiceReady{
|
||||
ServiceName: result.ServiceName,
|
||||
ServiceUrl: result.ServiceURL,
|
||||
Domain: result.Domain,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mgr.KeepAlive(ctx, result.Domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
||||
type serverAgent struct {
|
||||
s *Server
|
||||
}
|
||||
|
||||
func (a *serverAgent) Up(ctx context.Context) error {
|
||||
_, err := a.s.Up(ctx, &proto.UpRequest{})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *serverAgent) Down(ctx context.Context) error {
|
||||
_, err := a.s.Down(ctx, &proto.DownRequest{})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *serverAgent) Status() (internal.StatusType, error) {
|
||||
return internal.CtxGetState(a.s.rootCtx).Status()
|
||||
}
|
||||
|
||||
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||
switch req.GetType() {
|
||||
case proto.OSLifecycleRequest_WAKEUP:
|
||||
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
case proto.OSLifecycleRequest_SLEEP:
|
||||
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
default:
|
||||
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||
}
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
@@ -269,7 +268,7 @@ func getDefaultDaemonAddr() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return DefaultDaemonAddrWindows
|
||||
}
|
||||
return daemonaddr.ResolveUnixDaemonAddr(DefaultDaemonAddr)
|
||||
return DefaultDaemonAddr
|
||||
}
|
||||
|
||||
// DialOptions contains options for SSH connections
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -71,8 +70,6 @@ type ServerConfig struct {
|
||||
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Store StoreConfig `yaml:"store"`
|
||||
ActivityStore StoreConfig `yaml:"activityStore"`
|
||||
AuthStore StoreConfig `yaml:"authStore"`
|
||||
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
||||
}
|
||||
|
||||
@@ -173,8 +170,7 @@ type RelaysConfig struct {
|
||||
type StoreConfig struct {
|
||||
Engine string `yaml:"engine"`
|
||||
EncryptionKey string `yaml:"encryptionKey"`
|
||||
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
|
||||
File string `yaml:"file"` // SQLite database file path (optional, defaults to dataDir)
|
||||
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
|
||||
}
|
||||
|
||||
// ReverseProxyConfig contains reverse proxy settings
|
||||
@@ -536,74 +532,6 @@ func stripSignalProtocol(uri string) string {
|
||||
return uri
|
||||
}
|
||||
|
||||
func buildRelayConfig(relays RelaysConfig) (*nbconfig.Relay, error) {
|
||||
var ttl time.Duration
|
||||
if relays.CredentialsTTL != "" {
|
||||
var err error
|
||||
ttl, err = time.ParseDuration(relays.CredentialsTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", relays.CredentialsTTL, err)
|
||||
}
|
||||
}
|
||||
return &nbconfig.Relay{
|
||||
Addresses: relays.Addresses,
|
||||
CredentialsTTL: util.Duration{Duration: ttl},
|
||||
Secret: relays.Secret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildEmbeddedIdPConfig builds the embedded IdP configuration.
|
||||
// authStore overrides auth.storage when set.
|
||||
func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.EmbeddedIdPConfig, error) {
|
||||
authStorageType := mgmt.Auth.Storage.Type
|
||||
authStorageDSN := c.Server.AuthStore.DSN
|
||||
if c.Server.AuthStore.Engine != "" {
|
||||
authStorageType = c.Server.AuthStore.Engine
|
||||
}
|
||||
if authStorageType == "" {
|
||||
authStorageType = "sqlite3"
|
||||
}
|
||||
authStorageFile := ""
|
||||
if authStorageType == "postgres" {
|
||||
if authStorageDSN == "" {
|
||||
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
|
||||
}
|
||||
} else {
|
||||
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
|
||||
if c.Server.AuthStore.File != "" {
|
||||
authStorageFile = c.Server.AuthStore.File
|
||||
if !filepath.IsAbs(authStorageFile) {
|
||||
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg := &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: mgmt.Auth.Issuer,
|
||||
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
||||
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: authStorageType,
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
File: authStorageFile,
|
||||
DSN: authStorageDSN,
|
||||
},
|
||||
},
|
||||
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
||||
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
||||
}
|
||||
|
||||
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
|
||||
cfg.Owner = &idp.OwnerConfig{
|
||||
Email: mgmt.Auth.Owner.Email,
|
||||
Hash: mgmt.Auth.Owner.Password,
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// ToManagementConfig converts CombinedConfig to management server config
|
||||
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||
mgmt := c.Management
|
||||
@@ -622,11 +550,19 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||
// Build relay config
|
||||
var relayConfig *nbconfig.Relay
|
||||
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
|
||||
relay, err := buildRelayConfig(mgmt.Relays)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var ttl time.Duration
|
||||
if mgmt.Relays.CredentialsTTL != "" {
|
||||
var err error
|
||||
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
|
||||
}
|
||||
}
|
||||
relayConfig = &nbconfig.Relay{
|
||||
Addresses: mgmt.Relays.Addresses,
|
||||
CredentialsTTL: util.Duration{Duration: ttl},
|
||||
Secret: mgmt.Relays.Secret,
|
||||
}
|
||||
relayConfig = relay
|
||||
}
|
||||
|
||||
// Build signal config
|
||||
@@ -662,9 +598,31 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||
httpConfig := &nbconfig.HttpServerConfig{}
|
||||
|
||||
// Build embedded IDP config (always enabled in combined server)
|
||||
embeddedIdP, err := c.buildEmbeddedIdPConfig(mgmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
storageFile := mgmt.Auth.Storage.File
|
||||
if storageFile == "" {
|
||||
storageFile = path.Join(mgmt.DataDir, "idp.db")
|
||||
}
|
||||
|
||||
embeddedIdP := &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: mgmt.Auth.Issuer,
|
||||
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
||||
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: mgmt.Auth.Storage.Type,
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
File: storageFile,
|
||||
},
|
||||
},
|
||||
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
||||
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
||||
}
|
||||
|
||||
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
|
||||
embeddedIdP.Owner = &idp.OwnerConfig{
|
||||
Email: mgmt.Auth.Owner.Email,
|
||||
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
|
||||
}
|
||||
}
|
||||
|
||||
// Set HTTP config fields for embedded IDP
|
||||
|
||||
@@ -140,23 +140,6 @@ func initializeConfig() error {
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := config.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
if engine := config.Server.ActivityStore.Engine; engine != "" {
|
||||
engineLower := strings.ToLower(engine)
|
||||
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
|
||||
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
|
||||
}
|
||||
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
|
||||
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
|
||||
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := config.Server.ActivityStore.File; file != "" {
|
||||
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
log.Infof("Starting combined NetBird server")
|
||||
logConfig(config)
|
||||
@@ -685,11 +668,8 @@ func logEnvVars() {
|
||||
if strings.HasPrefix(env, "NB_") {
|
||||
key, _, _ := strings.Cut(env, "=")
|
||||
value := os.Getenv(key)
|
||||
keyLower := strings.ToLower(key)
|
||||
if strings.Contains(keyLower, "secret") || strings.Contains(keyLower, "key") || strings.Contains(keyLower, "password") {
|
||||
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
|
||||
value = maskSecret(value)
|
||||
} else if strings.Contains(keyLower, "dsn") {
|
||||
value = maskDSNPassword(value)
|
||||
}
|
||||
log.Infof(" %s=%s", key, value)
|
||||
found = true
|
||||
|
||||
@@ -42,9 +42,6 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := cfg.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
datadir := cfg.Management.DataDir
|
||||
engine := types.Engine(cfg.Management.Store.Engine)
|
||||
|
||||
111
combined/config-simple.yaml.example
Normal file
111
combined/config-simple.yaml.example
Normal file
@@ -0,0 +1,111 @@
|
||||
# NetBird Combined Server Configuration
|
||||
# Copy this file to config.yaml and customize for your deployment
|
||||
#
|
||||
# This is a Management server with optional embedded Signal, Relay, and STUN services.
|
||||
# By default, all services run locally. You can use external services instead by
|
||||
# setting the corresponding override fields.
|
||||
#
|
||||
# Architecture:
|
||||
# - Management: Always runs locally (this IS the management server)
|
||||
# - Signal: Local by default; set 'signalUri' to use external (disables local)
|
||||
# - Relay: Local by default; set 'relays' to use external (disables local)
|
||||
# - STUN: Local on port 3478 by default; set 'stuns' to use external instead
|
||||
|
||||
server:
|
||||
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
|
||||
listenAddress: ":443"
|
||||
|
||||
# Public address that peers will use to connect to this server
|
||||
# Used for relay connections and management DNS domain
|
||||
# Format: protocol://hostname:port (e.g., https://server.mycompany.com:443)
|
||||
exposedAddress: "https://server.mycompany.com:443"
|
||||
|
||||
# STUN server ports (defaults to [3478] if not specified; set 'stuns' to use external)
|
||||
# stunPorts:
|
||||
# - 3478
|
||||
|
||||
# Metrics endpoint port
|
||||
metricsPort: 9090
|
||||
|
||||
# Healthcheck endpoint address
|
||||
healthcheckAddress: ":9000"
|
||||
|
||||
# Logging configuration
|
||||
logLevel: "info" # Default log level for all components: panic, fatal, error, warn, info, debug, trace
|
||||
logFile: "console" # "console" or path to log file
|
||||
|
||||
# TLS configuration (optional)
|
||||
tls:
|
||||
certFile: ""
|
||||
keyFile: ""
|
||||
letsencrypt:
|
||||
enabled: false
|
||||
dataDir: ""
|
||||
domains: []
|
||||
email: ""
|
||||
awsRoute53: false
|
||||
|
||||
# Shared secret for relay authentication (required when running local relay)
|
||||
authSecret: "your-secret-key-here"
|
||||
|
||||
# Data directory for all services
|
||||
dataDir: "/var/lib/netbird/"
|
||||
|
||||
# ============================================================================
|
||||
# External Service Overrides (optional)
|
||||
# Use these to point to external Signal, Relay, or STUN servers instead of
|
||||
# running them locally. When set, the corresponding local service is disabled.
|
||||
# ============================================================================
|
||||
|
||||
# External STUN servers - disables local STUN server
|
||||
# stuns:
|
||||
# - uri: "stun:stun.example.com:3478"
|
||||
# - uri: "stun:stun.example.com:3479"
|
||||
|
||||
# External relay servers - disables local relay server
|
||||
# relays:
|
||||
# addresses:
|
||||
# - "rels://relay.example.com:443"
|
||||
# credentialsTTL: "12h"
|
||||
# secret: "relay-shared-secret"
|
||||
|
||||
# External signal server - disables local signal server
|
||||
# signalUri: "https://signal.example.com:443"
|
||||
|
||||
# ============================================================================
|
||||
# Management Settings
|
||||
# ============================================================================
|
||||
|
||||
# Metrics and updates
|
||||
disableAnonymousMetrics: false
|
||||
disableGeoliteUpdate: false
|
||||
|
||||
# Embedded authentication/identity provider (Dex) configuration (always enabled)
|
||||
auth:
|
||||
# OIDC issuer URL - must be publicly accessible
|
||||
issuer: "https://server.mycompany.com/oauth2"
|
||||
localAuthDisabled: false
|
||||
signKeyRefreshEnabled: false
|
||||
# OAuth2 redirect URIs for dashboard
|
||||
dashboardRedirectURIs:
|
||||
- "https://app.netbird.io/nb-auth"
|
||||
- "https://app.netbird.io/nb-silent-auth"
|
||||
# OAuth2 redirect URIs for CLI
|
||||
cliRedirectURIs:
|
||||
- "http://localhost:53000/"
|
||||
# Optional initial admin user
|
||||
# owner:
|
||||
# email: "admin@example.com"
|
||||
# password: "initial-password"
|
||||
|
||||
# Store configuration
|
||||
store:
|
||||
engine: "sqlite" # sqlite, postgres, or mysql
|
||||
dsn: "" # Connection string for postgres or mysql
|
||||
encryptionKey: ""
|
||||
|
||||
# Reverse proxy settings (optional)
|
||||
# reverseProxy:
|
||||
# trustedHTTPProxies: []
|
||||
# trustedHTTPProxiesCount: 0
|
||||
# trustedPeers: []
|
||||
@@ -1,29 +1,11 @@
|
||||
# NetBird Combined Server Configuration
|
||||
# Simplified Combined NetBird Server Configuration
|
||||
# Copy this file to config.yaml and customize for your deployment
|
||||
#
|
||||
# This is a Management server with optional embedded Signal, Relay, and STUN services.
|
||||
# By default, all services run locally. You can use external services instead by
|
||||
# setting the corresponding override fields.
|
||||
#
|
||||
# Architecture:
|
||||
# - Management: Always runs locally (this IS the management server)
|
||||
# - Signal: Local by default; set 'signalUri' to use external (disables local)
|
||||
# - Relay: Local by default; set 'relays' to use external (disables local)
|
||||
# - STUN: Local on port 3478 by default; set 'stuns' to use external instead
|
||||
|
||||
# Server-wide settings
|
||||
server:
|
||||
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
|
||||
listenAddress: ":443"
|
||||
|
||||
# Public address that peers will use to connect to this server
|
||||
# Used for relay connections and management DNS domain
|
||||
# Format: protocol://hostname:port (e.g., https://server.mycompany.com:443)
|
||||
exposedAddress: "https://server.mycompany.com:443"
|
||||
|
||||
# STUN server ports (defaults to [3478] if not specified; set 'stuns' to use external)
|
||||
# stunPorts:
|
||||
# - 3478
|
||||
|
||||
# Metrics endpoint port
|
||||
metricsPort: 9090
|
||||
|
||||
@@ -31,7 +13,7 @@ server:
|
||||
healthcheckAddress: ":9000"
|
||||
|
||||
# Logging configuration
|
||||
logLevel: "info" # Default log level for all components: panic, fatal, error, warn, info, debug, trace
|
||||
logLevel: "info" # panic, fatal, error, warn, info, debug, trace
|
||||
logFile: "console" # "console" or path to log file
|
||||
|
||||
# TLS configuration (optional)
|
||||
@@ -45,45 +27,53 @@ server:
|
||||
email: ""
|
||||
awsRoute53: false
|
||||
|
||||
# Shared secret for relay authentication (required when running local relay)
|
||||
# Relay service configuration
|
||||
relay:
|
||||
# Enable/disable the relay service
|
||||
enabled: true
|
||||
|
||||
# Public address that peers will use to connect to this relay
|
||||
# Format: hostname:port or ip:port
|
||||
exposedAddress: "relay.example.com:443"
|
||||
|
||||
# Shared secret for relay authentication (required when enabled)
|
||||
authSecret: "your-secret-key-here"
|
||||
|
||||
# Data directory for all services
|
||||
# Log level for relay (reserved for future use, currently uses global log level)
|
||||
logLevel: "info"
|
||||
|
||||
# Embedded STUN server (optional)
|
||||
stun:
|
||||
enabled: false
|
||||
ports: [3478]
|
||||
logLevel: "info"
|
||||
|
||||
# Signal service configuration
|
||||
signal:
|
||||
# Enable/disable the signal service
|
||||
enabled: true
|
||||
|
||||
# Log level for signal (reserved for future use, currently uses global log level)
|
||||
logLevel: "info"
|
||||
|
||||
# Management service configuration
|
||||
management:
|
||||
# Enable/disable the management service
|
||||
enabled: true
|
||||
|
||||
# Data directory for management service
|
||||
dataDir: "/var/lib/netbird/"
|
||||
|
||||
# ============================================================================
|
||||
# External Service Overrides (optional)
|
||||
# Use these to point to external Signal, Relay, or STUN servers instead of
|
||||
# running them locally. When set, the corresponding local service is disabled.
|
||||
# ============================================================================
|
||||
|
||||
# External STUN servers - disables local STUN server
|
||||
# stuns:
|
||||
# - uri: "stun:stun.example.com:3478"
|
||||
# - uri: "stun:stun.example.com:3479"
|
||||
|
||||
# External relay servers - disables local relay server
|
||||
# relays:
|
||||
# addresses:
|
||||
# - "rels://relay.example.com:443"
|
||||
# credentialsTTL: "12h"
|
||||
# secret: "relay-shared-secret"
|
||||
|
||||
# External signal server - disables local signal server
|
||||
# signalUri: "https://signal.example.com:443"
|
||||
|
||||
# ============================================================================
|
||||
# Management Settings
|
||||
# ============================================================================
|
||||
# DNS domain for the management server
|
||||
dnsDomain: ""
|
||||
|
||||
# Metrics and updates
|
||||
disableAnonymousMetrics: false
|
||||
disableGeoliteUpdate: false
|
||||
|
||||
# Embedded authentication/identity provider (Dex) configuration (always enabled)
|
||||
auth:
|
||||
# OIDC issuer URL - must be publicly accessible
|
||||
issuer: "https://example.com/oauth2"
|
||||
issuer: "https://management.example.com/oauth2"
|
||||
localAuthDisabled: false
|
||||
signKeyRefreshEnabled: false
|
||||
# OAuth2 redirect URIs for dashboard
|
||||
@@ -98,27 +88,28 @@ server:
|
||||
# email: "admin@example.com"
|
||||
# password: "initial-password"
|
||||
|
||||
# External STUN servers (for client config)
|
||||
stuns: []
|
||||
# - uri: "stun:stun.example.com:3478"
|
||||
|
||||
# External relay servers (for client config)
|
||||
relays:
|
||||
addresses: []
|
||||
# - "rels://relay.example.com:443"
|
||||
credentialsTTL: "12h"
|
||||
secret: ""
|
||||
|
||||
# External signal server URI (for client config)
|
||||
signalUri: ""
|
||||
|
||||
# Store configuration
|
||||
store:
|
||||
engine: "sqlite" # sqlite, postgres, or mysql
|
||||
dsn: "" # Connection string for postgres or mysql
|
||||
encryptionKey: ""
|
||||
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/store.db)
|
||||
|
||||
# Activity events store configuration (optional, defaults to sqlite in dataDir)
|
||||
# activityStore:
|
||||
# engine: "sqlite" # sqlite or postgres
|
||||
# dsn: "" # Connection string for postgres
|
||||
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/events.db)
|
||||
|
||||
# Auth (embedded IdP) store configuration (optional, defaults to sqlite3 in dataDir/idp.db)
|
||||
# authStore:
|
||||
# engine: "sqlite3" # sqlite3 or postgres
|
||||
# dsn: "" # Connection string for postgres (e.g., "host=localhost port=5432 user=postgres password=postgres dbname=netbird_idp sslmode=disable")
|
||||
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/idp.db)
|
||||
|
||||
# Reverse proxy settings (optional)
|
||||
# reverseProxy:
|
||||
# trustedHTTPProxies: []
|
||||
# trustedHTTPProxiesCount: 0
|
||||
# trustedPeers: []
|
||||
# Reverse proxy settings
|
||||
reverseProxy:
|
||||
trustedHTTPProxies: []
|
||||
trustedHTTPProxiesCount: 0
|
||||
trustedPeers: []
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package txt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/levels"
|
||||
@@ -16,7 +18,7 @@ type TextFormatter struct {
|
||||
func NewTextFormatter() *TextFormatter {
|
||||
return &TextFormatter{
|
||||
levelDesc: levels.ValidLevelDesc,
|
||||
timestampFormat: "2006-01-02T15:04:05.000Z07:00",
|
||||
timestampFormat: time.RFC3339, // or RFC3339
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,6 @@ func TestLogTextFormat(t *testing.T) {
|
||||
result, _ := formatter.Format(someEntry)
|
||||
|
||||
parsedString := string(result)
|
||||
expectedString := "^2021-02-21T01:10:30.000Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
||||
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
||||
assert.Regexp(t, expectedString, parsedString)
|
||||
}
|
||||
|
||||
19
go.mod
19
go.mod
@@ -93,10 +93,10 @@ require (
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/testcontainers/testcontainers-go v0.37.0
|
||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.37.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.37.0
|
||||
github.com/testcontainers/testcontainers-go v0.31.0
|
||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.31.0
|
||||
github.com/things-go/go-socks5 v0.0.4
|
||||
github.com/ti-mo/conntrack v0.5.1
|
||||
github.com/ti-mo/netfilter v0.5.2
|
||||
@@ -142,6 +142,7 @@ require (
|
||||
github.com/Masterminds/semver/v3 v3.3.0 // indirect
|
||||
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
@@ -165,16 +166,16 @@ require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/containerd v1.7.29 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v28.0.1+incompatible // indirect
|
||||
github.com/docker/docker v26.1.5+incompatible // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/ebitengine/purego v0.8.2 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fredbi/uri v1.1.1 // indirect
|
||||
github.com/fyne-io/gl-js v0.2.0 // indirect
|
||||
@@ -220,10 +221,9 @@ require (
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/libdns/libdns v0.2.2 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
||||
github.com/mdelapenya/tlscert v0.2.0 // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
|
||||
github.com/mholt/acmez/v2 v2.0.1 // indirect
|
||||
@@ -242,7 +242,7 @@ require (
|
||||
github.com/nxadm/tail v1.4.8 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||
@@ -256,7 +256,6 @@ require (
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/russellhaering/goxmldsig v1.5.0 // indirect
|
||||
github.com/rymdport/portal v0.4.2 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.25.1 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/spf13/cast v1.7.0 // indirect
|
||||
|
||||
45
go.sum
45
go.sum
@@ -33,6 +33,8 @@ github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe
|
||||
github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0=
|
||||
github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ=
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
|
||||
@@ -107,6 +109,8 @@ github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
|
||||
github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||
@@ -131,14 +135,12 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v28.0.1+incompatible h1:FCHjSRdXhNRFjlHMTv4jUNlIBbTeRjrWfeFuJp7jpo0=
|
||||
github.com/docker/docker v28.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/docker v26.1.5+incompatible h1:NEAxTwEjxV6VbBMBoGG3zPqbiJosIApZjxlbrG9q3/g=
|
||||
github.com/docker/docker v26.1.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I=
|
||||
github.com/ebitengine/purego v0.8.2/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/eko/gocache/lib/v4 v4.2.0 h1:MNykyi5Xw+5Wu3+PUrvtOCaKSZM1nUSVftbzmeC7Yuw=
|
||||
github.com/eko/gocache/lib/v4 v4.2.0/go.mod h1:7ViVmbU+CzDHzRpmB4SXKyyzyuJ8A3UW3/cszpcqB4M=
|
||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEMl+NivxtR5R3/hw=
|
||||
@@ -193,6 +195,8 @@ github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3yg
|
||||
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
||||
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||
@@ -353,15 +357,13 @@ github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/Y
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI=
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
|
||||
@@ -435,12 +437,13 @@ github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
|
||||
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
|
||||
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
|
||||
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8=
|
||||
github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
|
||||
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs=
|
||||
@@ -510,8 +513,6 @@ github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU
|
||||
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
|
||||
github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU=
|
||||
github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8=
|
||||
github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs=
|
||||
github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||
@@ -553,14 +554,14 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/testcontainers/testcontainers-go v0.37.0 h1:L2Qc0vkTw2EHWQ08djon0D2uw7Z/PtHS/QzZZ5Ra/hg=
|
||||
github.com/testcontainers/testcontainers-go v0.37.0/go.mod h1:QPzbxZhQ6Bclip9igjLFj6z0hs01bU8lrl2dHQmgFGM=
|
||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.37.0 h1:LqUos1oR5iuuzorFnSvxsHNdYdCHB/DfI82CuT58wbI=
|
||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.37.0/go.mod h1:vHEEHx5Kf+uq5hveaVAMrTzPY8eeRZcKcl23MRw5Tkc=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0 h1:hsVwFkS6s+79MbKEO+W7A1wNIw1fmkMtF4fg83m6kbc=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0/go.mod h1:Qj/eGbRbO/rEYdcRLmN+bEojzatP/+NS1y8ojl2PQsc=
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.37.0 h1:9HIY28I9ME/Zmb+zey1p/I1mto5+5ch0wLX+nJdOsQ4=
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.37.0/go.mod h1:Abu9g/25Qv+FkYVx3U4Voaynou1c+7D0HIhaQJXvk6E=
|
||||
github.com/testcontainers/testcontainers-go v0.31.0 h1:W0VwIhcEVhRflwL9as3dhY6jXjVCA27AkmbnZ+UTh3U=
|
||||
github.com/testcontainers/testcontainers-go v0.31.0/go.mod h1:D2lAoA0zUFiSY+eAflqK5mcUx/A5hrrORaEQrd0SefI=
|
||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 h1:790+S8ewZYCbG+o8IiFlZ8ZZ33XbNO6zV9qhU6xhlRk=
|
||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0/go.mod h1:REFmO+lSG9S6uSBEwIMZCxeI36uhScjTwChYADeO3JA=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 h1:isAwFS3KNKRbJMbWv+wolWqOFUECmjYZ+sIRZCIBc/E=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0/go.mod h1:ZNYY8vumNCEG9YI59A9d6/YaMY49uwRhmeU563EzFGw=
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.31.0 h1:5X6GhOdLwV86zcW8sxppJAMtsDC9u+r9tb3biBc9GKs=
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.31.0/go.mod h1:dKi5xBwy1k4u8yb3saQHu7hMEJwewHXxzbcMAuLiA6o=
|
||||
github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0=
|
||||
github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ=
|
||||
github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0=
|
||||
@@ -850,7 +851,7 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa
|
||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA=
|
||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ=
|
||||
|
||||
@@ -5,10 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -198,175 +195,11 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
|
||||
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
||||
}
|
||||
return (&sql.SQLite3{File: file}).Open(logger)
|
||||
case "postgres":
|
||||
dsn, _ := s.Config["dsn"].(string)
|
||||
if dsn == "" {
|
||||
return nil, fmt.Errorf("postgres storage requires 'dsn' config")
|
||||
}
|
||||
pg, err := parsePostgresDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid postgres DSN: %w", err)
|
||||
}
|
||||
return pg.Open(logger)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", s.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// parsePostgresDSN parses a DSN into a sql.Postgres config.
|
||||
// It accepts both URI format (postgres://user:pass@host:port/dbname?sslmode=disable)
|
||||
// and libpq key=value format (host=localhost port=5432 dbname=mydb), including quoted values.
|
||||
func parsePostgresDSN(dsn string) (*sql.Postgres, error) {
|
||||
var params map[string]string
|
||||
var err error
|
||||
|
||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||
params, err = parsePostgresURI(dsn)
|
||||
} else {
|
||||
params, err = parsePostgresKeyValue(dsn)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host := params["host"]
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
var port uint16 = 5432
|
||||
if p, ok := params["port"]; ok && p != "" {
|
||||
v, err := strconv.ParseUint(p, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port %q: %w", p, err)
|
||||
}
|
||||
if v == 0 {
|
||||
return nil, fmt.Errorf("invalid port %q: must be non-zero", p)
|
||||
}
|
||||
port = uint16(v)
|
||||
}
|
||||
|
||||
dbname := params["dbname"]
|
||||
if dbname == "" {
|
||||
return nil, fmt.Errorf("dbname is required in DSN")
|
||||
}
|
||||
|
||||
pg := &sql.Postgres{
|
||||
NetworkDB: sql.NetworkDB{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Database: dbname,
|
||||
User: params["user"],
|
||||
Password: params["password"],
|
||||
},
|
||||
}
|
||||
|
||||
if sslMode := params["sslmode"]; sslMode != "" {
|
||||
switch sslMode {
|
||||
case "disable", "allow", "prefer", "require", "verify-ca", "verify-full":
|
||||
pg.SSL.Mode = sslMode
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported sslmode %q: valid values are disable, allow, prefer, require, verify-ca, verify-full", sslMode)
|
||||
}
|
||||
}
|
||||
|
||||
return pg, nil
|
||||
}
|
||||
|
||||
// parsePostgresURI parses a postgres:// or postgresql:// URI into parameter key-value pairs.
|
||||
func parsePostgresURI(dsn string) (map[string]string, error) {
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid postgres URI: %w", err)
|
||||
}
|
||||
|
||||
params := make(map[string]string)
|
||||
|
||||
if u.User != nil {
|
||||
params["user"] = u.User.Username()
|
||||
if p, ok := u.User.Password(); ok {
|
||||
params["password"] = p
|
||||
}
|
||||
}
|
||||
if u.Hostname() != "" {
|
||||
params["host"] = u.Hostname()
|
||||
}
|
||||
if u.Port() != "" {
|
||||
params["port"] = u.Port()
|
||||
}
|
||||
|
||||
dbname := strings.TrimPrefix(u.Path, "/")
|
||||
if dbname != "" {
|
||||
params["dbname"] = dbname
|
||||
}
|
||||
|
||||
for k, v := range u.Query() {
|
||||
if len(v) > 0 {
|
||||
params[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// parsePostgresKeyValue parses a libpq key=value DSN string, handling single-quoted values
|
||||
// (e.g., password='my pass' host=localhost).
|
||||
func parsePostgresKeyValue(dsn string) (map[string]string, error) {
|
||||
params := make(map[string]string)
|
||||
s := strings.TrimSpace(dsn)
|
||||
|
||||
for s != "" {
|
||||
eqIdx := strings.IndexByte(s, '=')
|
||||
if eqIdx < 0 {
|
||||
break
|
||||
}
|
||||
key := strings.TrimSpace(s[:eqIdx])
|
||||
|
||||
value, rest, err := parseDSNValue(s[eqIdx+1:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w for key %q", err, key)
|
||||
}
|
||||
|
||||
params[key] = value
|
||||
s = strings.TrimSpace(rest)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// parseDSNValue parses the next value from a libpq key=value string positioned after the '='.
|
||||
// It returns the parsed value and the remaining unparsed string.
|
||||
func parseDSNValue(s string) (value, rest string, err error) {
|
||||
if len(s) > 0 && s[0] == '\'' {
|
||||
return parseQuotedDSNValue(s[1:])
|
||||
}
|
||||
// Unquoted value: read until whitespace.
|
||||
idx := strings.IndexAny(s, " \t\n")
|
||||
if idx < 0 {
|
||||
return s, "", nil
|
||||
}
|
||||
return s[:idx], s[idx:], nil
|
||||
}
|
||||
|
||||
// parseQuotedDSNValue parses a single-quoted value starting after the opening quote.
|
||||
// Libpq uses ” to represent a literal single quote inside quoted values.
|
||||
func parseQuotedDSNValue(s string) (value, rest string, err error) {
|
||||
var buf strings.Builder
|
||||
for len(s) > 0 {
|
||||
if s[0] == '\'' {
|
||||
if len(s) > 1 && s[1] == '\'' {
|
||||
buf.WriteByte('\'')
|
||||
s = s[2:]
|
||||
continue
|
||||
}
|
||||
return buf.String(), s[1:], nil
|
||||
}
|
||||
buf.WriteByte(s[0])
|
||||
s = s[1:]
|
||||
}
|
||||
return "", "", fmt.Errorf("unterminated quoted value")
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
func (c *YAMLConfig) Validate() error {
|
||||
if c.Issuer == "" {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -63,8 +63,6 @@ type Controller struct {
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
|
||||
compactedNetworkMap bool
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -87,12 +85,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err)
|
||||
compactedNetworkMap = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
@@ -116,8 +108,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
holder: types.NewHolder(),
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
|
||||
compactedNetworkMap: compactedNetworkMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,12 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountID):
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
@@ -368,12 +355,9 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountId):
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
@@ -495,12 +479,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -875,12 +854,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
account.InjectProxyPolicies(ctx)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
|
||||
@@ -3,7 +3,6 @@ package accesslogs
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -12,39 +11,15 @@ const (
|
||||
DefaultPageSize = 50
|
||||
// MaxPageSize is the maximum number of records allowed per page
|
||||
MaxPageSize = 100
|
||||
|
||||
// Default sorting
|
||||
DefaultSortBy = "timestamp"
|
||||
DefaultSortOrder = "desc"
|
||||
)
|
||||
|
||||
// Valid sortable fields mapped to their database column names or expressions
|
||||
// For multi-column sorts, columns are separated by comma (e.g., "host, path")
|
||||
var validSortFields = map[string]string{
|
||||
"timestamp": "timestamp",
|
||||
"url": "host, path", // Sort by host first, then path
|
||||
"host": "host",
|
||||
"path": "path",
|
||||
"method": "method",
|
||||
"status_code": "status_code",
|
||||
"duration": "duration",
|
||||
"source_ip": "location_connection_ip",
|
||||
"user_id": "user_id",
|
||||
"auth_method": "auth_method_used",
|
||||
"reason": "reason",
|
||||
}
|
||||
|
||||
// AccessLogFilter holds pagination, filtering, and sorting parameters for access logs
|
||||
// AccessLogFilter holds pagination and filtering parameters for access logs
|
||||
type AccessLogFilter struct {
|
||||
// Page is the current page number (1-indexed)
|
||||
Page int
|
||||
// PageSize is the number of records per page
|
||||
PageSize int
|
||||
|
||||
// Sorting parameters
|
||||
SortBy string // Field to sort by: timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason
|
||||
SortOrder string // Sort order: asc or desc (default: desc)
|
||||
|
||||
// Filtering parameters
|
||||
Search *string // General search across log ID, host, path, source IP, and user fields
|
||||
SourceIP *string // Filter by source IP address
|
||||
@@ -60,16 +35,13 @@ type AccessLogFilter struct {
|
||||
EndDate *time.Time // Filter by timestamp <= end_date
|
||||
}
|
||||
|
||||
// ParseFromRequest parses pagination, sorting, and filter parameters from HTTP request query parameters
|
||||
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
|
||||
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||
queryParams := r.URL.Query()
|
||||
|
||||
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
|
||||
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
|
||||
|
||||
f.SortBy = parseSortField(queryParams.Get("sort_by"))
|
||||
f.SortOrder = parseSortOrder(queryParams.Get("sort_order"))
|
||||
|
||||
f.Search = parseOptionalString(queryParams.Get("search"))
|
||||
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
|
||||
f.Host = parseOptionalString(queryParams.Get("host"))
|
||||
@@ -135,44 +107,3 @@ func (f *AccessLogFilter) GetOffset() int {
|
||||
func (f *AccessLogFilter) GetLimit() int {
|
||||
return f.PageSize
|
||||
}
|
||||
|
||||
// GetSortColumn returns the validated database column name for sorting
|
||||
func (f *AccessLogFilter) GetSortColumn() string {
|
||||
if column, ok := validSortFields[f.SortBy]; ok {
|
||||
return column
|
||||
}
|
||||
return validSortFields[DefaultSortBy]
|
||||
}
|
||||
|
||||
// GetSortOrder returns the validated sort order (ASC or DESC)
|
||||
func (f *AccessLogFilter) GetSortOrder() string {
|
||||
if f.SortOrder == "asc" || f.SortOrder == "desc" {
|
||||
return f.SortOrder
|
||||
}
|
||||
return DefaultSortOrder
|
||||
}
|
||||
|
||||
// parseSortField validates and returns the sort field, defaulting if invalid
|
||||
func parseSortField(s string) string {
|
||||
if s == "" {
|
||||
return DefaultSortBy
|
||||
}
|
||||
// Check if the field is valid
|
||||
if _, ok := validSortFields[s]; ok {
|
||||
return s
|
||||
}
|
||||
return DefaultSortBy
|
||||
}
|
||||
|
||||
// parseSortOrder validates and returns the sort order, defaulting if invalid
|
||||
func parseSortOrder(s string) string {
|
||||
if s == "" {
|
||||
return DefaultSortOrder
|
||||
}
|
||||
// Normalize to lowercase
|
||||
s = strings.ToLower(s)
|
||||
if s == "asc" || s == "desc" {
|
||||
return s
|
||||
}
|
||||
return DefaultSortOrder
|
||||
}
|
||||
|
||||
@@ -361,205 +361,6 @@ func TestParseOptionalRFC3339(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_SortingDefaults(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, DefaultSortBy, filter.SortBy, "SortBy should default to timestamp")
|
||||
assert.Equal(t, DefaultSortOrder, filter.SortOrder, "SortOrder should default to desc")
|
||||
assert.Equal(t, "timestamp", filter.GetSortColumn(), "GetSortColumn should return timestamp")
|
||||
assert.Equal(t, "desc", filter.GetSortOrder(), "GetSortOrder should return desc")
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_ValidSortFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
expectedColumn string
|
||||
expectedSortByVal string
|
||||
}{
|
||||
{"timestamp", "timestamp", "timestamp", "timestamp"},
|
||||
{"url", "url", "host, path", "url"},
|
||||
{"host", "host", "host", "host"},
|
||||
{"path", "path", "path", "path"},
|
||||
{"method", "method", "method", "method"},
|
||||
{"status_code", "status_code", "status_code", "status_code"},
|
||||
{"duration", "duration", "duration", "duration"},
|
||||
{"source_ip", "source_ip", "location_connection_ip", "source_ip"},
|
||||
{"user_id", "user_id", "user_id", "user_id"},
|
||||
{"auth_method", "auth_method", "auth_method_used", "auth_method"},
|
||||
{"reason", "reason", "reason", "reason"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy, nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expectedSortByVal, filter.SortBy, "SortBy mismatch")
|
||||
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn(), "GetSortColumn mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_InvalidSortField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
expected string
|
||||
}{
|
||||
{"invalid field", "invalid_field", DefaultSortBy},
|
||||
{"empty field", "", DefaultSortBy},
|
||||
{"malicious input", "timestamp--DROP", DefaultSortBy},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
q := req.URL.Query()
|
||||
q.Set("sort_by", tt.sortBy)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expected, filter.SortBy, "Invalid sort field should default to timestamp")
|
||||
assert.Equal(t, validSortFields[DefaultSortBy], filter.GetSortColumn())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_SortOrder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortOrder string
|
||||
expected string
|
||||
}{
|
||||
{"ascending", "asc", "asc"},
|
||||
{"descending", "desc", "desc"},
|
||||
{"uppercase ASC", "ASC", "asc"},
|
||||
{"uppercase DESC", "DESC", "desc"},
|
||||
{"mixed case Asc", "Asc", "asc"},
|
||||
{"invalid order", "invalid", DefaultSortOrder},
|
||||
{"empty order", "", DefaultSortOrder},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?sort_order="+tt.sortOrder, nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expected, filter.GetSortOrder(), "GetSortOrder mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_CompleteSortingScenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
sortOrder string
|
||||
expectedColumn string
|
||||
expectedOrder string
|
||||
}{
|
||||
{
|
||||
name: "sort by host ascending",
|
||||
sortBy: "host",
|
||||
sortOrder: "asc",
|
||||
expectedColumn: "host",
|
||||
expectedOrder: "asc",
|
||||
},
|
||||
{
|
||||
name: "sort by duration descending",
|
||||
sortBy: "duration",
|
||||
sortOrder: "desc",
|
||||
expectedColumn: "duration",
|
||||
expectedOrder: "desc",
|
||||
},
|
||||
{
|
||||
name: "sort by status_code ascending",
|
||||
sortBy: "status_code",
|
||||
sortOrder: "asc",
|
||||
expectedColumn: "status_code",
|
||||
expectedOrder: "asc",
|
||||
},
|
||||
{
|
||||
name: "invalid sort with valid order",
|
||||
sortBy: "invalid",
|
||||
sortOrder: "asc",
|
||||
expectedColumn: "timestamp",
|
||||
expectedOrder: "asc",
|
||||
},
|
||||
{
|
||||
name: "valid sort with invalid order",
|
||||
sortBy: "method",
|
||||
sortOrder: "invalid",
|
||||
expectedColumn: "method",
|
||||
expectedOrder: DefaultSortOrder,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy+"&sort_order="+tt.sortOrder, nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn())
|
||||
assert.Equal(t, tt.expectedOrder, filter.GetSortOrder())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"valid field", "host", "host"},
|
||||
{"empty string", "", DefaultSortBy},
|
||||
{"invalid field", "invalid", DefaultSortBy},
|
||||
{"malicious input", "timestamp--DROP", DefaultSortBy},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseSortField(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortOrder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"asc lowercase", "asc", "asc"},
|
||||
{"desc lowercase", "desc", "desc"},
|
||||
{"ASC uppercase", "ASC", "asc"},
|
||||
{"DESC uppercase", "DESC", "desc"},
|
||||
{"invalid", "invalid", DefaultSortOrder},
|
||||
{"empty", "", DefaultSortOrder},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseSortOrder(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for creating pointers
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
|
||||
@@ -7,7 +7,4 @@ import (
|
||||
type Manager interface {
|
||||
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
|
||||
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
|
||||
CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error)
|
||||
StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int)
|
||||
StopPeriodicCleanup()
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package manager
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -20,7 +19,6 @@ type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
geo geolocation.Geolocation
|
||||
cleanupCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
|
||||
@@ -80,74 +78,6 @@ func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID st
|
||||
return logs, totalCount, nil
|
||||
}
|
||||
|
||||
// CleanupOldAccessLogs deletes access logs older than the specified retention period
|
||||
func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
|
||||
if retentionDays <= 0 {
|
||||
log.WithContext(ctx).Debug("access log cleanup skipped: retention days is 0 or negative")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
cutoffTime := time.Now().AddDate(0, 0, -retentionDays)
|
||||
deletedCount, err := m.store.DeleteOldAccessLogs(ctx, cutoffTime)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup old access logs: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
log.WithContext(ctx).Infof("cleaned up %d access logs older than %d days", deletedCount, retentionDays)
|
||||
}
|
||||
|
||||
return deletedCount, nil
|
||||
}
|
||||
|
||||
// StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs
|
||||
func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
|
||||
if retentionDays <= 0 {
|
||||
log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative")
|
||||
return
|
||||
}
|
||||
|
||||
if cleanupIntervalHours <= 0 {
|
||||
cleanupIntervalHours = 24
|
||||
}
|
||||
|
||||
cleanupCtx, cancel := context.WithCancel(ctx)
|
||||
m.cleanupCancel = cancel
|
||||
|
||||
cleanupInterval := time.Duration(cleanupIntervalHours) * time.Hour
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run cleanup immediately on startup
|
||||
log.WithContext(cleanupCtx).Infof("starting access log cleanup routine (retention: %d days, interval: %d hours)", retentionDays, cleanupIntervalHours)
|
||||
if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
|
||||
log.WithContext(cleanupCtx).Errorf("initial access log cleanup failed: %v", err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cleanupCtx.Done():
|
||||
log.WithContext(cleanupCtx).Info("stopping access log cleanup routine")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
|
||||
log.WithContext(cleanupCtx).Errorf("periodic access log cleanup failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// StopPeriodicCleanup stops the periodic cleanup routine
|
||||
func (m *managerImpl) StopPeriodicCleanup() {
|
||||
if m.cleanupCancel != nil {
|
||||
m.cleanupCancel()
|
||||
}
|
||||
}
|
||||
|
||||
// resolveUserFilters converts user email/name filters to user ID filter
|
||||
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
|
||||
if filter.UserEmail == nil && filter.UserName == nil {
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
func TestCleanupOldAccessLogs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
retentionDays int
|
||||
setupMock func(*store.MockStore)
|
||||
expectedCount int64
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "cleanup logs older than retention period",
|
||||
retentionDays: 30,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
expectedCutoff := time.Now().AddDate(0, 0, -30)
|
||||
timeDiff := olderThan.Sub(expectedCutoff)
|
||||
if timeDiff.Abs() > time.Second {
|
||||
t.Errorf("cutoff time not as expected: got %v, want ~%v", olderThan, expectedCutoff)
|
||||
}
|
||||
return 5, nil
|
||||
})
|
||||
},
|
||||
expectedCount: 5,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "no logs to cleanup",
|
||||
retentionDays: 30,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(0), nil)
|
||||
},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "zero retention days skips cleanup",
|
||||
retentionDays: 0,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
// No expectations - DeleteOldAccessLogs should not be called
|
||||
},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "negative retention days skips cleanup",
|
||||
retentionDays: -10,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
// No expectations - DeleteOldAccessLogs should not be called
|
||||
},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
deletedCount, err := manager.CleanupOldAccessLogs(ctx, tt.retentionDays)
|
||||
|
||||
if tt.expectedError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.expectedCount, deletedCount, "unexpected number of deleted logs")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupWithExactBoundary(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
expectedCutoff := time.Now().AddDate(0, 0, -30)
|
||||
timeDiff := olderThan.Sub(expectedCutoff)
|
||||
assert.Less(t, timeDiff.Abs(), time.Second, "cutoff time should be close to expected value")
|
||||
return 1, nil
|
||||
})
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
deletedCount, err := manager.CleanupOldAccessLogs(ctx, 30)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), deletedCount)
|
||||
}
|
||||
|
||||
func TestStartPeriodicCleanup(t *testing.T) {
|
||||
t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
// No expectations - cleanup should not run
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 0, 1)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// If DeleteOldAccessLogs was called, the test will fail due to unexpected call
|
||||
})
|
||||
|
||||
t.Run("periodic cleanup runs immediately on start", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(2), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Expectations verified by gomock on defer ctrl.Finish()
|
||||
})
|
||||
|
||||
t.Run("periodic cleanup stops on context cancel", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(1), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
})
|
||||
|
||||
t.Run("cleanup interval defaults to 24 hours when invalid", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(0), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 0)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.StopPeriodicCleanup()
|
||||
})
|
||||
|
||||
t.Run("cleanup interval uses configured hours", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(3), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 12)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.StopPeriodicCleanup()
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopPeriodicCleanup(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(1), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.StopPeriodicCleanup()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Expectations verified by gomock - would fail if more than 1 call happened
|
||||
}
|
||||
|
||||
func TestStopPeriodicCleanup_NotStarted(t *testing.T) {
|
||||
manager := &managerImpl{}
|
||||
|
||||
// Should not panic if cleanup was never started
|
||||
manager.StopPeriodicCleanup()
|
||||
}
|
||||
@@ -9,5 +9,4 @@ type Manager interface {
|
||||
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
|
||||
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
|
||||
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
|
||||
GetClusterDomains() []string
|
||||
}
|
||||
|
||||
@@ -27,21 +27,21 @@ type store interface {
|
||||
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
||||
}
|
||||
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
type proxyURLProvider interface {
|
||||
GetConnectedProxyURLs() []string
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyManager proxyManager
|
||||
proxyURLProvider proxyURLProvider
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
|
||||
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
|
||||
return Manager{
|
||||
store: store,
|
||||
proxyManager: proxyMgr,
|
||||
store: store,
|
||||
proxyURLProvider: proxyURLProvider,
|
||||
validator: domain.Validator{
|
||||
Resolver: net.DefaultResolver,
|
||||
},
|
||||
@@ -67,12 +67,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
|
||||
// Add connected proxy clusters as free domains.
|
||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
allowList := m.proxyURLAllowList()
|
||||
log.WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"proxyAllowList": allowList,
|
||||
}).Debug("getting domains with proxy allow list")
|
||||
@@ -111,10 +107,7 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
allowList := m.proxyURLAllowList()
|
||||
clusterValid := false
|
||||
for _, cluster := range allowList {
|
||||
if cluster == targetCluster {
|
||||
@@ -228,26 +221,21 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
|
||||
}
|
||||
}
|
||||
|
||||
// GetClusterDomains returns a list of proxy cluster domains.
|
||||
func (m Manager) GetClusterDomains() []string {
|
||||
if m.proxyManager == nil {
|
||||
return nil
|
||||
// proxyURLAllowList retrieves a list of currently connected proxies and
|
||||
// their URLs
|
||||
func (m Manager) proxyURLAllowList() []string {
|
||||
var reverseProxyAddresses []string
|
||||
if m.proxyURLProvider != nil {
|
||||
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
|
||||
}
|
||||
addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return addresses
|
||||
return reverseProxyAddresses
|
||||
}
|
||||
|
||||
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
allowList := m.proxyURLAllowList()
|
||||
if len(allowList) == 0 {
|
||||
return "", fmt.Errorf("no proxy clusters available")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package service
|
||||
package reverseproxy
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -12,17 +12,12 @@ type Manager interface {
|
||||
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
||||
DeleteAllServices(ctx context.Context, accountID, userID string) error
|
||||
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
||||
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
|
||||
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
|
||||
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
||||
ReloadService(ctx context.Context, accountID, serviceID string) error
|
||||
GetGlobalServices(ctx context.Context) ([]*Service, error)
|
||||
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
|
||||
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
||||
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
||||
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
|
||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||
StartExposeReaper(ctx context.Context)
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./interface.go
|
||||
|
||||
// Package service is a generated GoMock package.
|
||||
package service
|
||||
// Package reverseproxy is a generated GoMock package.
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
context "context"
|
||||
@@ -49,35 +49,6 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
|
||||
}
|
||||
|
||||
// CreateServiceFromPeer mocks base method.
|
||||
func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, req)
|
||||
ret0, _ := ret[0].(*ExposeServiceResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, req interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteService mocks base method.
|
||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -210,20 +181,6 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer mocks base method.
|
||||
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
|
||||
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// SetCertificateIssuedAt mocks base method.
|
||||
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -239,7 +196,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
|
||||
}
|
||||
|
||||
// SetStatus mocks base method.
|
||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
|
||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -252,32 +209,6 @@ func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status i
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
|
||||
}
|
||||
|
||||
// StartExposeReaper mocks base method.
|
||||
func (m *MockManager) StartExposeReaper(ctx context.Context) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "StartExposeReaper", ctx)
|
||||
}
|
||||
|
||||
// StartExposeReaper indicates an expected call of StartExposeReaper.
|
||||
func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartExposeReaper", reflect.TypeOf((*MockManager)(nil).StartExposeReaper), ctx)
|
||||
}
|
||||
|
||||
// StopServiceFromPeer mocks base method.
|
||||
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
|
||||
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// UpdateService mocks base method.
|
||||
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -17,11 +17,11 @@ import (
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager rpservice.Manager
|
||||
manager reverseproxy.Manager
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers all service HTTP endpoints.
|
||||
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
service := new(rpservice.Service)
|
||||
service := new(reverseproxy.Service)
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
|
||||
if err = service.Validate(); err != nil {
|
||||
@@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
service := new(rpservice.Service)
|
||||
service := new(reverseproxy.Service)
|
||||
service.ID = serviceID
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
|
||||
539
management/internals/modules/reverseproxy/manager/manager.go
Normal file
539
management/internals/modules/reverseproxy/manager/manager.go
Normal file
@@ -0,0 +1,539 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const unknownHostPlaceholder = "unknown"
|
||||
|
||||
// ClusterDeriver derives the proxy cluster from a domain.
|
||||
type ClusterDeriver interface {
|
||||
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||
clusterDeriver ClusterDeriver
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyGRPCServer: proxyGRPCServer,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
for _, target := range service.Targets {
|
||||
switch target.TargetType {
|
||||
case reverseproxy.TargetTypePeer:
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = peer.IP.String()
|
||||
case reverseproxy.TargetTypeHost:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Prefix.Addr().String()
|
||||
case reverseproxy.TargetTypeDomain:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Domain
|
||||
case reverseproxy.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, service); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||
}
|
||||
service.ProxyCluster = proxyCluster
|
||||
}
|
||||
|
||||
service.AccountID = accountID
|
||||
service.InitNewRecord()
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate session keys: %w", err)
|
||||
}
|
||||
service.SessionPrivateKey = keyPair.PrivateKey
|
||||
service.SessionPublicKey = keyPair.PublicKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if existingService != nil && existingService.ID != excludeServiceID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
|
||||
|
||||
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
type serviceUpdateInfo struct {
|
||||
oldCluster string
|
||||
domainChanged bool
|
||||
serviceEnabledChanged bool
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateInfo.oldCluster = existingService.ProxyCluster
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
if updateInfo.domainChanged {
|
||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
m.preserveServiceMetadata(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||
} else {
|
||||
service.ProxyCluster = newCluster
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
}
|
||||
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||
service.Auth.PinAuth.Pin == "" {
|
||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
|
||||
switch {
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
case !service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
|
||||
case service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
default:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
|
||||
}
|
||||
}
|
||||
|
||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
|
||||
for _, target := range targets {
|
||||
switch target.TargetType {
|
||||
case reverseproxy.TargetTypePeer:
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
||||
}
|
||||
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
|
||||
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
|
||||
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
service.Meta.CertificateIssuedAt = time.Now()
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
service.Meta.Status = string(status)
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to update service status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
|
||||
}
|
||||
|
||||
if target == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return target.ServiceID, nil
|
||||
}
|
||||
@@ -0,0 +1,375 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func TestInitializeServiceForCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service := &reverseproxy.Service{
|
||||
Domain: "example.com",
|
||||
Auth: reverseproxy.AuthConfig{},
|
||||
}
|
||||
|
||||
err := mgr.initializeServiceForCreate(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accountID, service.AccountID)
|
||||
assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver")
|
||||
assert.NotEmpty(t, service.ID, "service ID should be initialized")
|
||||
assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated")
|
||||
assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated")
|
||||
})
|
||||
|
||||
t.Run("verifies session keys are different", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
|
||||
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
|
||||
|
||||
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
|
||||
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
|
||||
|
||||
assert.NoError(t, err1)
|
||||
assert.NoError(t, err2)
|
||||
assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique")
|
||||
assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckDomainAvailable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
excludeServiceID string
|
||||
setupMock func(*store.MockStore)
|
||||
expectedError bool
|
||||
errorType status.Type
|
||||
}{
|
||||
{
|
||||
name: "domain available - not found",
|
||||
domain: "available.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "available.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "domain already exists",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
},
|
||||
{
|
||||
name: "domain exists but excluded (same ID)",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "service-123",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "domain exists with different ID",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "service-456",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
},
|
||||
{
|
||||
name: "store error (non-NotFound)",
|
||||
domain: "error.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "error.com").
|
||||
Return(nil, errors.New("database error"))
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
|
||||
|
||||
if tt.expectedError {
|
||||
require.Error(t, err)
|
||||
if tt.errorType != 0 {
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok, "error should be a status error")
|
||||
assert.Equal(t, tt.errorType, sErr.Type())
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("empty domain", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty exclude ID with existing service", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "test.com").
|
||||
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
|
||||
|
||||
assert.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, sErr.Type())
|
||||
})
|
||||
|
||||
t.Run("nil existing service with nil error", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "nil.com").
|
||||
Return(nil, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPersistNewService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful service creation with no targets", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-123",
|
||||
Domain: "new.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
}
|
||||
|
||||
// Mock ExecuteInTransaction to execute the function immediately
|
||||
mockStore.EXPECT().
|
||||
ExecuteInTransaction(ctx, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
// Create another mock for the transaction
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "new.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
txMock.EXPECT().
|
||||
CreateService(ctx, service).
|
||||
Return(nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("domain already exists", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-123",
|
||||
Domain: "existing.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
ExecuteInTransaction(ctx, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "existing.com").
|
||||
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, sErr.Type())
|
||||
})
|
||||
}
|
||||
func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
|
||||
t.Run("preserve password when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "hashed-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
|
||||
})
|
||||
|
||||
t.Run("preserve pin when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "hashed-pin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth)
|
||||
})
|
||||
|
||||
t.Run("do not preserve when password is provided", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "old-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "new-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password)
|
||||
assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPreserveServiceMetadata(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
|
||||
existing := &reverseproxy.Service{
|
||||
Meta: reverseproxy.ServiceMeta{
|
||||
CertificateIssuedAt: time.Now(),
|
||||
Status: "active",
|
||||
},
|
||||
SessionPrivateKey: "private-key",
|
||||
SessionPublicKey: "public-key",
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Domain: "updated.com",
|
||||
}
|
||||
|
||||
mgr.preserveServiceMetadata(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Meta, updated.Meta)
|
||||
assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey)
|
||||
assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey)
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package proxy
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
||||
type OIDCValidationConfig struct {
|
||||
Issuer string
|
||||
Audiences []string
|
||||
KeysLocation string
|
||||
MaxTokenAgeSeconds int64
|
||||
}
|
||||
|
||||
// Controller is responsible for managing proxy clusters and routing service updates.
|
||||
type Controller interface {
|
||||
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
|
||||
GetOIDCValidationConfig() OIDCValidationConfig
|
||||
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
GetProxiesForCluster(clusterAddr string) []string
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
|
||||
type GRPCController struct {
|
||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
metrics *metrics
|
||||
}
|
||||
|
||||
// NewGRPCController creates a new GRPCController.
|
||||
func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) {
|
||||
m, err := newMetrics(meter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GRPCController{
|
||||
proxyGRPCServer: proxyGRPCServer,
|
||||
metrics: m,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
|
||||
func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
|
||||
c.metrics.IncrementServiceUpdateSendCount(clusterAddr)
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
|
||||
func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
||||
return c.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
|
||||
func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
if clusterAddr == "" {
|
||||
return nil
|
||||
}
|
||||
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
|
||||
|
||||
c.metrics.IncrementProxyConnectionCount(clusterAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster removes a proxy from a cluster.
|
||||
func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
if clusterAddr == "" {
|
||||
return nil
|
||||
}
|
||||
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
|
||||
proxySet.(*sync.Map).Delete(proxyID)
|
||||
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
|
||||
|
||||
c.metrics.DecrementProxyConnectionCount(clusterAddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
|
||||
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
proxySet, ok := c.clusterProxies.Load(clusterAddr)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var proxies []string
|
||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||
proxies = append(proxies, key.(string))
|
||||
return true
|
||||
})
|
||||
return proxies
|
||||
}
|
||||
@@ -1,115 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
// store defines the interface for proxy persistence operations
|
||||
type store interface {
|
||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
type Manager struct {
|
||||
store store
|
||||
metrics *metrics
|
||||
}
|
||||
|
||||
// NewManager creates a new proxy Manager
|
||||
func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
m, err := newMetrics(meter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Manager{
|
||||
store: store,
|
||||
metrics: m,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connect registers a new proxy connection in the database
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"clusterAddress": clusterAddress,
|
||||
"ipAddress": ipAddress,
|
||||
}).Info("proxy connected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
Status: "disconnected",
|
||||
DisconnectedAt: &now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
}).Info("proxy disconnected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp
|
||||
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
m.metrics.IncrementProxyHeartbeatCount()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
type metrics struct {
|
||||
proxyConnectionCount metric.Int64UpDownCounter
|
||||
serviceUpdateSendCount metric.Int64Counter
|
||||
proxyHeartbeatCount metric.Int64Counter
|
||||
}
|
||||
|
||||
func newMetrics(meter metric.Meter) (*metrics, error) {
|
||||
proxyConnectionCount, err := meter.Int64UpDownCounter(
|
||||
"management_proxy_connection_count",
|
||||
metric.WithDescription("Number of active proxy connections"),
|
||||
metric.WithUnit("{connection}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceUpdateSendCount, err := meter.Int64Counter(
|
||||
"management_proxy_service_update_send_count",
|
||||
metric.WithDescription("Total number of service updates sent to proxies"),
|
||||
metric.WithUnit("{update}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxyHeartbeatCount, err := meter.Int64Counter(
|
||||
"management_proxy_heartbeat_count",
|
||||
metric.WithDescription("Total number of proxy heartbeats received"),
|
||||
metric.WithUnit("{heartbeat}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &metrics{
|
||||
proxyConnectionCount: proxyConnectionCount,
|
||||
serviceUpdateSendCount: serviceUpdateSendCount,
|
||||
proxyHeartbeatCount: proxyHeartbeatCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) {
|
||||
m.proxyConnectionCount.Add(context.Background(), 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("cluster", clusterAddr),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) {
|
||||
m.proxyConnectionCount.Add(context.Background(), -1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("cluster", clusterAddr),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) {
|
||||
m.serviceUpdateSendCount.Add(context.Background(), 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("cluster", clusterAddr),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *metrics) IncrementProxyHeartbeatCount() {
|
||||
m.proxyHeartbeatCount.Add(context.Background(), 1)
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./manager.go
|
||||
|
||||
// Package proxy is a generated GoMock package.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
proto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
type MockManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockManagerMockRecorder is the mock recorder for MockManager.
|
||||
type MockManagerMockRecorder struct {
|
||||
mock *MockManager
|
||||
}
|
||||
|
||||
// NewMockManager creates a new mock instance.
|
||||
func NewMockManager(ctrl *gomock.Controller) *MockManager {
|
||||
mock := &MockManager{ctrl: ctrl}
|
||||
mock.recorder = &MockManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CleanupStale mocks base method.
|
||||
func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CleanupStale indicates an expected call of CleanupStale.
|
||||
func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Disconnect indicates an expected call of Disconnect.
|
||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses mocks base method.
|
||||
func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Heartbeat indicates an expected call of Heartbeat.
|
||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockControllerMockRecorder
|
||||
}
|
||||
|
||||
// MockControllerMockRecorder is the mock recorder for MockController.
|
||||
type MockControllerMockRecorder struct {
|
||||
mock *MockController
|
||||
}
|
||||
|
||||
// NewMockController creates a new mock instance.
|
||||
func NewMockController(ctrl *gomock.Controller) *MockController {
|
||||
mock := &MockController{ctrl: ctrl}
|
||||
mock.recorder = &MockControllerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig mocks base method.
|
||||
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
|
||||
ret0, _ := ret[0].(OIDCValidationConfig)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
|
||||
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
|
||||
}
|
||||
|
||||
// GetProxiesForCluster mocks base method.
|
||||
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
|
||||
func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr)
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster mocks base method.
|
||||
func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
|
||||
func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster mocks base method.
|
||||
func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
|
||||
func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster mocks base method.
|
||||
func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
|
||||
func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import "time"
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
@@ -1,20 +1,16 @@
|
||||
package service
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
|
||||
@@ -30,23 +26,20 @@ const (
|
||||
Delete Operation = "delete"
|
||||
)
|
||||
|
||||
type Status string
|
||||
type ProxyStatus string
|
||||
|
||||
const (
|
||||
StatusPending Status = "pending"
|
||||
StatusActive Status = "active"
|
||||
StatusTunnelNotCreated Status = "tunnel_not_created"
|
||||
StatusCertificatePending Status = "certificate_pending"
|
||||
StatusCertificateFailed Status = "certificate_failed"
|
||||
StatusError Status = "error"
|
||||
StatusPending ProxyStatus = "pending"
|
||||
StatusActive ProxyStatus = "active"
|
||||
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
|
||||
StatusCertificatePending ProxyStatus = "certificate_pending"
|
||||
StatusCertificateFailed ProxyStatus = "certificate_failed"
|
||||
StatusError ProxyStatus = "error"
|
||||
|
||||
TargetTypePeer = "peer"
|
||||
TargetTypeHost = "host"
|
||||
TargetTypeDomain = "domain"
|
||||
TargetTypeSubnet = "subnet"
|
||||
|
||||
SourcePermanent = "permanent"
|
||||
SourceEphemeral = "ephemeral"
|
||||
)
|
||||
|
||||
type Target struct {
|
||||
@@ -112,11 +105,17 @@ func (a *AuthConfig) ClearSecrets() {
|
||||
}
|
||||
}
|
||||
|
||||
type Meta struct {
|
||||
type OIDCValidationConfig struct {
|
||||
Issuer string
|
||||
Audiences []string
|
||||
KeysLocation string
|
||||
MaxTokenAgeSeconds int64
|
||||
}
|
||||
|
||||
type ServiceMeta struct {
|
||||
CreatedAt time.Time
|
||||
CertificateIssuedAt *time.Time
|
||||
CertificateIssuedAt time.Time
|
||||
Status string
|
||||
LastRenewedAt *time.Time
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
@@ -129,12 +128,10 @@ type Service struct {
|
||||
Enabled bool
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
Source string `gorm:"default:'permanent'"`
|
||||
SourcePeer string
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
}
|
||||
|
||||
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
||||
@@ -159,7 +156,7 @@ func NewService(accountID, name, domain, proxyCluster string, targets []*Target,
|
||||
// only be called during initial creation, not for updates.
|
||||
func (s *Service) InitNewRecord() {
|
||||
s.ID = xid.New().String()
|
||||
s.Meta = Meta{
|
||||
s.Meta = ServiceMeta{
|
||||
CreatedAt: time.Now(),
|
||||
Status: string(StatusPending),
|
||||
}
|
||||
@@ -210,8 +207,8 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
Status: api.ServiceMetaStatus(s.Meta.Status),
|
||||
}
|
||||
|
||||
if s.Meta.CertificateIssuedAt != nil {
|
||||
meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt
|
||||
if !s.Meta.CertificateIssuedAt.IsZero() {
|
||||
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
|
||||
}
|
||||
|
||||
resp := &api.Service{
|
||||
@@ -233,7 +230,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
|
||||
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
if !target.Enabled {
|
||||
@@ -406,11 +403,7 @@ func (s *Service) Validate() error {
|
||||
}
|
||||
|
||||
func (s *Service) EventMeta() map[string]any {
|
||||
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
|
||||
}
|
||||
|
||||
func (s *Service) isAuthEnabled() bool {
|
||||
return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil
|
||||
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
|
||||
}
|
||||
|
||||
func (s *Service) Copy() *Service {
|
||||
@@ -434,8 +427,6 @@ func (s *Service) Copy() *Service {
|
||||
Meta: s.Meta,
|
||||
SessionPrivateKey: s.SessionPrivateKey,
|
||||
SessionPublicKey: s.SessionPublicKey,
|
||||
Source: s.Source,
|
||||
SourcePeer: s.SourcePeer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -470,140 +461,3 @@ func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||
|
||||
const alphanumCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`)
|
||||
|
||||
// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service.
|
||||
type ExposeServiceRequest struct {
|
||||
NamePrefix string
|
||||
Port int
|
||||
Protocol string
|
||||
Domain string
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
}
|
||||
|
||||
// Validate checks all fields of the expose request.
|
||||
func (r *ExposeServiceRequest) Validate() error {
|
||||
if r == nil {
|
||||
return errors.New("request cannot be nil")
|
||||
}
|
||||
|
||||
if r.Port < 1 || r.Port > 65535 {
|
||||
return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port)
|
||||
}
|
||||
|
||||
if r.Protocol != "http" && r.Protocol != "https" {
|
||||
return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol)
|
||||
}
|
||||
|
||||
if r.Pin != "" && !pinRegexp.MatchString(r.Pin) {
|
||||
return errors.New("invalid pin: must be exactly 6 digits")
|
||||
}
|
||||
|
||||
for _, g := range r.UserGroups {
|
||||
if g == "" {
|
||||
return errors.New("user group name cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
if r.NamePrefix != "" && !validNamePrefix.MatchString(r.NamePrefix) {
|
||||
return fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", r.NamePrefix)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToService builds a Service from the expose request.
|
||||
func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service {
|
||||
service := &Service{
|
||||
AccountID: accountID,
|
||||
Name: serviceName,
|
||||
Enabled: true,
|
||||
Targets: []*Target{
|
||||
{
|
||||
AccountID: accountID,
|
||||
Port: r.Port,
|
||||
Protocol: r.Protocol,
|
||||
TargetId: peerID,
|
||||
TargetType: TargetTypePeer,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if r.Domain != "" {
|
||||
service.Domain = serviceName + "." + r.Domain
|
||||
}
|
||||
|
||||
if r.Pin != "" {
|
||||
service.Auth.PinAuth = &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: r.Pin,
|
||||
}
|
||||
}
|
||||
|
||||
if r.Password != "" {
|
||||
service.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: r.Password,
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.UserGroups) > 0 {
|
||||
service.Auth.BearerAuth = &BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: r.UserGroups,
|
||||
}
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
// ExposeServiceResponse contains the result of a successful peer expose creation.
|
||||
type ExposeServiceResponse struct {
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
}
|
||||
|
||||
// GenerateExposeName generates a random service name for peer-exposed services.
|
||||
// The prefix, if provided, must be a valid DNS label component (lowercase alphanumeric and hyphens).
|
||||
func GenerateExposeName(prefix string) (string, error) {
|
||||
if prefix != "" && !validNamePrefix.MatchString(prefix) {
|
||||
return "", fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", prefix)
|
||||
}
|
||||
|
||||
suffixLen := 12
|
||||
if prefix != "" {
|
||||
suffixLen = 4
|
||||
}
|
||||
|
||||
suffix, err := randomAlphanumeric(suffixLen)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate random name: %w", err)
|
||||
}
|
||||
|
||||
if prefix == "" {
|
||||
return suffix, nil
|
||||
}
|
||||
return prefix + "-" + suffix, nil
|
||||
}
|
||||
|
||||
func randomAlphanumeric(n int) (string, error) {
|
||||
result := make([]byte, n)
|
||||
charsetLen := big.NewInt(int64(len(alphanumCharset)))
|
||||
for i := range result {
|
||||
idx, err := rand.Int(rand.Reader, charsetLen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result[i] = alphanumCharset[idx.Int64()]
|
||||
}
|
||||
return string(result), nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -110,7 +109,7 @@ func TestIsDefaultPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||
oidcConfig := proxy.OIDCValidationConfig{}
|
||||
oidcConfig := OIDCValidationConfig{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -203,7 +202,7 @@ func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
||||
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
||||
}
|
||||
@@ -220,7 +219,7 @@ func TestToProtoMapping_OperationTypes(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.op), func(t *testing.T) {
|
||||
pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{})
|
||||
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
|
||||
assert.Equal(t, tt.want, pm.Type)
|
||||
})
|
||||
}
|
||||
@@ -404,146 +403,3 @@ func TestAuthConfig_ClearSecrets(t *testing.T) {
|
||||
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateExposeName(t *testing.T) {
|
||||
t.Run("no prefix generates 12-char name", func(t *testing.T) {
|
||||
name, err := GenerateExposeName("")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, name, 12)
|
||||
assert.Regexp(t, `^[a-z0-9]+$`, name)
|
||||
})
|
||||
|
||||
t.Run("with prefix generates prefix-XXXX", func(t *testing.T) {
|
||||
name, err := GenerateExposeName("myapp")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, strings.HasPrefix(name, "myapp-"), "name should start with prefix")
|
||||
suffix := strings.TrimPrefix(name, "myapp-")
|
||||
assert.Len(t, suffix, 4, "suffix should be 4 chars")
|
||||
assert.Regexp(t, `^[a-z0-9]+$`, suffix)
|
||||
})
|
||||
|
||||
t.Run("unique names", func(t *testing.T) {
|
||||
names := make(map[string]bool)
|
||||
for i := 0; i < 50; i++ {
|
||||
name, err := GenerateExposeName("")
|
||||
require.NoError(t, err)
|
||||
names[name] = true
|
||||
}
|
||||
assert.Greater(t, len(names), 45, "should generate mostly unique names")
|
||||
})
|
||||
|
||||
t.Run("valid prefixes", func(t *testing.T) {
|
||||
validPrefixes := []string{"a", "ab", "a1", "my-app", "web-server-01", "a-b"}
|
||||
for _, prefix := range validPrefixes {
|
||||
name, err := GenerateExposeName(prefix)
|
||||
assert.NoError(t, err, "prefix %q should be valid", prefix)
|
||||
assert.True(t, strings.HasPrefix(name, prefix+"-"), "name should start with %q-", prefix)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid prefixes", func(t *testing.T) {
|
||||
invalidPrefixes := []string{
|
||||
"-starts-with-dash",
|
||||
"ends-with-dash-",
|
||||
"has.dots",
|
||||
"HAS-UPPER",
|
||||
"has spaces",
|
||||
"has/slash",
|
||||
"a--",
|
||||
}
|
||||
for _, prefix := range invalidPrefixes {
|
||||
_, err := GenerateExposeName(prefix)
|
||||
assert.Error(t, err, "prefix %q should be invalid", prefix)
|
||||
assert.Contains(t, err.Error(), "invalid name prefix")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExposeServiceRequest_ToService(t *testing.T) {
|
||||
t.Run("basic HTTP service", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
}
|
||||
|
||||
service := req.ToService("account-1", "peer-1", "mysvc")
|
||||
|
||||
assert.Equal(t, "account-1", service.AccountID)
|
||||
assert.Equal(t, "mysvc", service.Name)
|
||||
assert.True(t, service.Enabled)
|
||||
assert.Empty(t, service.Domain, "domain should be empty when not specified")
|
||||
require.Len(t, service.Targets, 1)
|
||||
|
||||
target := service.Targets[0]
|
||||
assert.Equal(t, 8080, target.Port)
|
||||
assert.Equal(t, "http", target.Protocol)
|
||||
assert.Equal(t, "peer-1", target.TargetId)
|
||||
assert.Equal(t, TargetTypePeer, target.TargetType)
|
||||
assert.True(t, target.Enabled)
|
||||
assert.Equal(t, "account-1", target.AccountID)
|
||||
})
|
||||
|
||||
t.Run("with custom domain", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 3000,
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "web")
|
||||
assert.Equal(t, "web.example.com", service.Domain)
|
||||
})
|
||||
|
||||
t.Run("with PIN auth", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 80,
|
||||
Pin: "1234",
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "svc")
|
||||
require.NotNil(t, service.Auth.PinAuth)
|
||||
assert.True(t, service.Auth.PinAuth.Enabled)
|
||||
assert.Equal(t, "1234", service.Auth.PinAuth.Pin)
|
||||
assert.Nil(t, service.Auth.PasswordAuth)
|
||||
assert.Nil(t, service.Auth.BearerAuth)
|
||||
})
|
||||
|
||||
t.Run("with password auth", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 80,
|
||||
Password: "secret",
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "svc")
|
||||
require.NotNil(t, service.Auth.PasswordAuth)
|
||||
assert.True(t, service.Auth.PasswordAuth.Enabled)
|
||||
assert.Equal(t, "secret", service.Auth.PasswordAuth.Password)
|
||||
})
|
||||
|
||||
t.Run("with user groups (bearer auth)", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 80,
|
||||
UserGroups: []string{"admins", "devs"},
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "svc")
|
||||
require.NotNil(t, service.Auth.BearerAuth)
|
||||
assert.True(t, service.Auth.BearerAuth.Enabled)
|
||||
assert.Equal(t, []string{"admins", "devs"}, service.Auth.BearerAuth.DistributionGroups)
|
||||
})
|
||||
|
||||
t.Run("with all auth types", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 443,
|
||||
Domain: "myco.com",
|
||||
Pin: "9999",
|
||||
Password: "pass",
|
||||
UserGroups: []string{"ops"},
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "full")
|
||||
assert.Equal(t, "full.myco.com", service.Domain)
|
||||
require.NotNil(t, service.Auth.PinAuth)
|
||||
require.NotNil(t, service.Auth.PasswordAuth)
|
||||
require.NotNil(t, service.Auth.BearerAuth)
|
||||
})
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
exposeTTL = 90 * time.Second
|
||||
exposeReapInterval = 30 * time.Second
|
||||
maxExposesPerPeer = 10
|
||||
)
|
||||
|
||||
type trackedExpose struct {
|
||||
mu sync.Mutex
|
||||
domain string
|
||||
accountID string
|
||||
peerID string
|
||||
lastRenewed time.Time
|
||||
expiring bool
|
||||
}
|
||||
|
||||
type exposeTracker struct {
|
||||
activeExposes sync.Map
|
||||
exposeCreateMu sync.Mutex
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
func exposeKey(peerID, domain string) string {
|
||||
return peerID + ":" + domain
|
||||
}
|
||||
|
||||
// TrackExposeIfAllowed atomically checks the per-peer limit and registers a new
|
||||
// active expose session under the same lock. Returns (true, false) if the expose
|
||||
// was already tracked (duplicate), (false, true) if tracking succeeded, and
|
||||
// (false, false) if the peer has reached the limit.
|
||||
func (t *exposeTracker) TrackExposeIfAllowed(peerID, domain, accountID string) (alreadyTracked, ok bool) {
|
||||
t.exposeCreateMu.Lock()
|
||||
defer t.exposeCreateMu.Unlock()
|
||||
|
||||
key := exposeKey(peerID, domain)
|
||||
_, loaded := t.activeExposes.LoadOrStore(key, &trackedExpose{
|
||||
domain: domain,
|
||||
accountID: accountID,
|
||||
peerID: peerID,
|
||||
lastRenewed: time.Now(),
|
||||
})
|
||||
if loaded {
|
||||
return true, false
|
||||
}
|
||||
|
||||
if t.CountPeerExposes(peerID) > maxExposesPerPeer {
|
||||
t.activeExposes.Delete(key)
|
||||
return false, false
|
||||
}
|
||||
|
||||
return false, true
|
||||
}
|
||||
|
||||
// UntrackExpose removes an active expose session from tracking.
|
||||
func (t *exposeTracker) UntrackExpose(peerID, domain string) {
|
||||
t.activeExposes.Delete(exposeKey(peerID, domain))
|
||||
}
|
||||
|
||||
// CountPeerExposes returns the number of active expose sessions for a peer.
|
||||
func (t *exposeTracker) CountPeerExposes(peerID string) int {
|
||||
count := 0
|
||||
t.activeExposes.Range(func(_, val any) bool {
|
||||
if expose := val.(*trackedExpose); expose.peerID == peerID {
|
||||
count++
|
||||
}
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
// MaxExposesPerPeer returns the maximum number of concurrent exposes allowed per peer.
|
||||
func (t *exposeTracker) MaxExposesPerPeer() int {
|
||||
return maxExposesPerPeer
|
||||
}
|
||||
|
||||
// RenewTrackedExpose updates the in-memory lastRenewed timestamp for a tracked expose.
|
||||
// Returns false if the expose is not tracked or is being reaped.
|
||||
func (t *exposeTracker) RenewTrackedExpose(peerID, domain string) bool {
|
||||
key := exposeKey(peerID, domain)
|
||||
val, ok := t.activeExposes.Load(key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
if expose.expiring {
|
||||
expose.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
expose.lastRenewed = time.Now()
|
||||
expose.mu.Unlock()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// StopTrackedExpose removes an active expose session from tracking.
|
||||
// Returns false if the expose was not tracked.
|
||||
func (t *exposeTracker) StopTrackedExpose(peerID, domain string) bool {
|
||||
key := exposeKey(peerID, domain)
|
||||
_, ok := t.activeExposes.LoadAndDelete(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// StartExposeReaper starts a background goroutine that reaps expired expose sessions.
|
||||
func (t *exposeTracker) StartExposeReaper(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(exposeReapInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
t.reapExpiredExposes()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *exposeTracker) reapExpiredExposes() {
|
||||
t.activeExposes.Range(func(key, val any) bool {
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expired := time.Since(expose.lastRenewed) > exposeTTL
|
||||
if expired {
|
||||
expose.expiring = true
|
||||
}
|
||||
expose.mu.Unlock()
|
||||
|
||||
if !expired {
|
||||
return true
|
||||
}
|
||||
|
||||
log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain)
|
||||
|
||||
err := t.manager.deleteServiceFromPeer(context.Background(), expose.accountID, expose.peerID, expose.domain, true)
|
||||
|
||||
s, _ := status.FromError(err)
|
||||
|
||||
switch {
|
||||
case err == nil:
|
||||
t.activeExposes.Delete(key)
|
||||
case s.ErrorType == status.NotFound:
|
||||
log.Debugf("service %s was already deleted", expose.domain)
|
||||
default:
|
||||
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", expose.domain, err)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -1,256 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
)
|
||||
|
||||
func TestExposeKey(t *testing.T) {
|
||||
assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com"))
|
||||
assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com"))
|
||||
assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com"))
|
||||
}
|
||||
|
||||
func TestTrackExposeIfAllowed(t *testing.T) {
|
||||
t.Run("first track succeeds", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
assert.False(t, alreadyTracked, "first track should not be duplicate")
|
||||
assert.True(t, ok, "first track should be allowed")
|
||||
})
|
||||
|
||||
t.Run("duplicate track detected", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
|
||||
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
assert.True(t, alreadyTracked, "second track should be duplicate")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("rejects when at limit", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
for i := range maxExposesPerPeer {
|
||||
_, ok := tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
|
||||
assert.True(t, ok, "track %d should be allowed", i)
|
||||
}
|
||||
|
||||
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "over-limit.com", "acct1")
|
||||
assert.False(t, alreadyTracked)
|
||||
assert.False(t, ok, "should reject when at limit")
|
||||
})
|
||||
|
||||
t.Run("other peer unaffected by limit", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
for i := range maxExposesPerPeer {
|
||||
tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
|
||||
}
|
||||
|
||||
_, ok := tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
|
||||
assert.True(t, ok, "other peer should still be within limit")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUntrackExpose(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
assert.Equal(t, 1, tracker.CountPeerExposes("peer1"))
|
||||
|
||||
tracker.UntrackExpose("peer1", "a.com")
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
|
||||
}
|
||||
|
||||
func TestCountPeerExposes(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
|
||||
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
tracker.TrackExposeIfAllowed("peer1", "b.com", "acct1")
|
||||
tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
|
||||
|
||||
assert.Equal(t, 2, tracker.CountPeerExposes("peer1"), "peer1 should have 2 exposes")
|
||||
assert.Equal(t, 1, tracker.CountPeerExposes("peer2"), "peer2 should have 1 expose")
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes("peer3"), "peer3 should have 0 exposes")
|
||||
}
|
||||
|
||||
func TestMaxExposesPerPeer(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
assert.Equal(t, maxExposesPerPeer, tracker.MaxExposesPerPeer())
|
||||
}
|
||||
|
||||
func TestRenewTrackedExpose(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
|
||||
found := tracker.RenewTrackedExpose("peer1", "a.com")
|
||||
assert.False(t, found, "should not find untracked expose")
|
||||
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
|
||||
found = tracker.RenewTrackedExpose("peer1", "a.com")
|
||||
assert.True(t, found, "should find tracked expose")
|
||||
}
|
||||
|
||||
func TestRenewTrackedExpose_RejectsExpiring(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
|
||||
// Simulate reaper marking the expose as expiring
|
||||
key := exposeKey("peer1", "a.com")
|
||||
val, _ := tracker.activeExposes.Load(key)
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expose.expiring = true
|
||||
expose.mu.Unlock()
|
||||
|
||||
found := tracker.RenewTrackedExpose("peer1", "a.com")
|
||||
assert.False(t, found, "should reject renewal when expiring")
|
||||
}
|
||||
|
||||
func TestReapExpiredExposes(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
tracker := mgr.exposeTracker
|
||||
|
||||
ctx := context.Background()
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually expire the tracked entry
|
||||
key := exposeKey(testPeerID, resp.Domain)
|
||||
val, _ := tracker.activeExposes.Load(key)
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
|
||||
// Add an active (non-expired) tracking entry
|
||||
tracker.activeExposes.Store(exposeKey("peer1", "active.com"), &trackedExpose{
|
||||
domain: "active.com",
|
||||
accountID: testAccountID,
|
||||
peerID: "peer1",
|
||||
lastRenewed: time.Now(),
|
||||
})
|
||||
|
||||
tracker.reapExpiredExposes()
|
||||
|
||||
_, exists := tracker.activeExposes.Load(key)
|
||||
assert.False(t, exists, "expired expose should be removed")
|
||||
|
||||
_, exists = tracker.activeExposes.Load(exposeKey("peer1", "active.com"))
|
||||
assert.True(t, exists, "active expose should remain")
|
||||
}
|
||||
|
||||
func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
tracker := mgr.exposeTracker
|
||||
|
||||
ctx := context.Background()
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
key := exposeKey(testPeerID, resp.Domain)
|
||||
val, _ := tracker.activeExposes.Load(key)
|
||||
expose := val.(*trackedExpose)
|
||||
|
||||
// Expire it
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
|
||||
// Renew should succeed before reaping
|
||||
assert.True(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should succeed before reaper runs")
|
||||
|
||||
// Re-expire and reap
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
|
||||
tracker.reapExpiredExposes()
|
||||
|
||||
// Entry is deleted, renew returns false
|
||||
assert.False(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should fail after reap")
|
||||
}
|
||||
|
||||
func TestConcurrentTrackAndCount(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
tracker := mgr.exposeTracker
|
||||
ctx := context.Background()
|
||||
|
||||
for i := range 5 {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080 + i,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Manually expire all tracked entries
|
||||
tracker.activeExposes.Range(func(_, val any) bool {
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
return true
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.reapExpiredExposes()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.CountPeerExposes(testPeerID)
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes(testPeerID), "all expired exposes should be reaped")
|
||||
}
|
||||
|
||||
func TestTrackedExposeMutexProtectsLastRenewed(t *testing.T) {
|
||||
expose := &trackedExpose{
|
||||
lastRenewed: time.Now().Add(-1 * time.Hour),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 100 {
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now()
|
||||
expose.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 100 {
|
||||
expose.mu.Lock()
|
||||
_ = time.Since(expose.lastRenewed)
|
||||
expose.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
expose.mu.Lock()
|
||||
require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access")
|
||||
expose.mu.Unlock()
|
||||
}
|
||||
@@ -1,860 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const unknownHostPlaceholder = "unknown"
|
||||
|
||||
// ClusterDeriver derives the proxy cluster from a domain.
|
||||
type ClusterDeriver interface {
|
||||
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
|
||||
GetClusterDomains() []string
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyController proxy.Controller
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeTracker *exposeTracker
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
|
||||
mgr := &Manager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyController: proxyController,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
mgr.exposeTracker = &exposeTracker{manager: mgr}
|
||||
return mgr
|
||||
}
|
||||
|
||||
// StartExposeReaper delegates to the expose tracker.
|
||||
func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
m.exposeTracker.StartExposeReaper(ctx)
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
|
||||
for _, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case service.TargetTypePeer:
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = peer.IP.String()
|
||||
case service.TargetTypeHost:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Prefix.Addr().String()
|
||||
case service.TargetTypeDomain:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Domain
|
||||
case service.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||
}
|
||||
service.ProxyCluster = proxyCluster
|
||||
}
|
||||
|
||||
service.AccountID = accountID
|
||||
service.InitNewRecord()
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate session keys: %w", err)
|
||||
}
|
||||
service.SessionPrivateKey = keyPair.PrivateKey
|
||||
service.SessionPublicKey = keyPair.PublicKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if existingService != nil && existingService.ID != excludeServiceID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
|
||||
|
||||
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
type serviceUpdateInfo struct {
|
||||
oldCluster string
|
||||
domainChanged bool
|
||||
serviceEnabledChanged bool
|
||||
}
|
||||
|
||||
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateInfo.oldCluster = existingService.ProxyCluster
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
if updateInfo.domainChanged {
|
||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
m.preserveServiceMetadata(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||
} else {
|
||||
service.ProxyCluster = newCluster
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
}
|
||||
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||
service.Auth.PinAuth.Pin == "" {
|
||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
}
|
||||
|
||||
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
switch {
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||
case !s.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
|
||||
case s.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||
default:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||
}
|
||||
}
|
||||
|
||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
|
||||
for _, target := range targets {
|
||||
switch target.TargetType {
|
||||
case service.TargetTypePeer:
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
||||
}
|
||||
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
|
||||
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var s *service.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.Source == service.SourceEphemeral {
|
||||
m.exposeTracker.UntrackExpose(s.SourcePeer, s.Domain)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var services []*service.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, svc := range services {
|
||||
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc.Source == service.SourceEphemeral {
|
||||
m.exposeTracker.UntrackExpose(svc.SourcePeer, svc.Domain)
|
||||
}
|
||||
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
service.Meta.CertificateIssuedAt = &now
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
service.Meta.Status = string(status)
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to update service status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, s := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
|
||||
}
|
||||
|
||||
if target == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return target.ServiceID, nil
|
||||
}
|
||||
|
||||
// validateExposePermission checks whether the peer is allowed to use the expose feature.
|
||||
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
|
||||
func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return status.Errorf(status.Internal, "get account settings: %v", err)
|
||||
}
|
||||
|
||||
if !settings.PeerExposeEnabled {
|
||||
return status.Errorf(status.PermissionDenied, "peer expose is not enabled for this account")
|
||||
}
|
||||
|
||||
if len(settings.PeerExposeGroups) == 0 {
|
||||
return status.Errorf(status.PermissionDenied, "no group is set for peer expose")
|
||||
}
|
||||
|
||||
peerGroupIDs, err := m.store.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peer group IDs: %v", err)
|
||||
return status.Errorf(status.Internal, "get peer groups: %v", err)
|
||||
}
|
||||
|
||||
for _, pg := range peerGroupIDs {
|
||||
if slices.Contains(settings.PeerExposeGroups, pg) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
|
||||
}
|
||||
|
||||
// CreateServiceFromPeer creates a service initiated by a peer expose request.
|
||||
// It validates the request, checks expose permissions, enforces the per-peer limit,
|
||||
// creates the service, and tracks it for TTL-based reaping.
|
||||
func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
if err := req.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
|
||||
}
|
||||
|
||||
if err := m.validateExposePermission(ctx, accountID, peerID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceName, err := service.GenerateExposeName(req.NamePrefix)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
|
||||
}
|
||||
|
||||
svc := req.ToService(accountID, peerID, serviceName)
|
||||
svc.Source = service.SourceEphemeral
|
||||
|
||||
if svc.Domain == "" {
|
||||
domain, err := m.buildRandomDomain(svc.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
|
||||
}
|
||||
svc.Domain = domain
|
||||
}
|
||||
|
||||
if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
|
||||
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
|
||||
}
|
||||
svc.Auth.BearerAuth.DistributionGroups = groupIDs
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
svc.Meta.LastRenewedAt = &now
|
||||
svc.SourcePeer = peerID
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, svc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, svc.Domain, accountID)
|
||||
if alreadyTracked {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", svc.Domain, err)
|
||||
}
|
||||
return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
|
||||
}
|
||||
if !allowed {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", svc.Domain, err)
|
||||
}
|
||||
return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
||||
}
|
||||
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
|
||||
|
||||
if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
|
||||
return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
|
||||
}
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return &service.ExposeServiceResponse{
|
||||
ServiceName: svc.Name,
|
||||
ServiceURL: "https://" + svc.Domain,
|
||||
Domain: svc.Domain,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
|
||||
if len(groupNames) == 0 {
|
||||
return []string{}, fmt.Errorf("no group names provided")
|
||||
}
|
||||
groupIDs := make([]string, 0, len(groupNames))
|
||||
for _, groupName := range groupNames {
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
|
||||
}
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildRandomDomain(name string) (string, error) {
|
||||
if m.clusterDeriver == nil {
|
||||
return "", fmt.Errorf("unable to get random domain")
|
||||
}
|
||||
clusterDomains := m.clusterDeriver.GetClusterDomains()
|
||||
if len(clusterDomains) == 0 {
|
||||
return "", fmt.Errorf("no cluster domains found for service %s", name)
|
||||
}
|
||||
index := rand.IntN(len(clusterDomains))
|
||||
domain := name + "." + clusterDomains[index]
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session.
|
||||
// Returns an error if the expose is not actively tracked.
|
||||
func (m *Manager) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
|
||||
if !m.exposeTracker.RenewTrackedExpose(peerID, domain) {
|
||||
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service.
|
||||
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !m.exposeTracker.StopTrackedExpose(peerID, domain) {
|
||||
log.WithContext(ctx).Warnf("expose tracker entry for domain %s already removed; service was deleted", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
|
||||
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
|
||||
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
|
||||
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
activityCode := activity.PeerServiceUnexposed
|
||||
if expired {
|
||||
activityCode = activity.PeerServiceExposeExpired
|
||||
}
|
||||
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
|
||||
}
|
||||
|
||||
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
|
||||
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
|
||||
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
|
||||
}
|
||||
|
||||
if svc.SourcePeer != peerID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
|
||||
var svc *service.Service
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
|
||||
}
|
||||
|
||||
if svc.SourcePeer != peerID {
|
||||
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
|
||||
peer = nil
|
||||
}
|
||||
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
|
||||
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any {
|
||||
if peer == nil {
|
||||
return meta
|
||||
}
|
||||
meta["peer_name"] = peer.Name
|
||||
if peer.IP != nil {
|
||||
meta["peer_ip"] = peer.IP.String()
|
||||
}
|
||||
return meta
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create certificate service: %v", err)
|
||||
log.Fatalf("failed to create certificate manager: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
@@ -152,11 +152,6 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
serviceMgr := s.ServiceManager()
|
||||
srv.SetReverseProxyManager(serviceMgr)
|
||||
if serviceMgr != nil {
|
||||
serviceMgr.StartExposeReaper(context.Background())
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||
|
||||
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
|
||||
@@ -168,10 +163,9 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
proxyService.SetProxyManager(s.ReverseProxyManager())
|
||||
})
|
||||
return proxyService
|
||||
})
|
||||
@@ -194,10 +188,7 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
||||
|
||||
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||
return Create(s, func() *nbgrpc.OneTimeTokenStore {
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create proxy token store: %v", err)
|
||||
}
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||
log.Info("One-time token store initialized for proxy authentication")
|
||||
return tokenStore
|
||||
})
|
||||
@@ -206,11 +197,6 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||
return Create(s, func() accesslogs.Manager {
|
||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||
accessLogManager.StartPeriodicCleanup(
|
||||
context.Background(),
|
||||
s.Config.ReverseProxy.AccessLogRetentionDays,
|
||||
s.Config.ReverseProxy.AccessLogCleanupIntervalHours,
|
||||
)
|
||||
return accessLogManager
|
||||
})
|
||||
}
|
||||
|
||||
@@ -200,13 +200,4 @@ type ReverseProxy struct {
|
||||
// request headers if the peer's address falls within one of these
|
||||
// trusted IP prefixes.
|
||||
TrustedPeers []netip.Prefix
|
||||
|
||||
// AccessLogRetentionDays specifies the number of days to retain access logs.
|
||||
// Logs older than this duration will be automatically deleted during cleanup.
|
||||
// A value of 0 or negative means logs are kept indefinitely (no cleanup).
|
||||
AccessLogRetentionDays int
|
||||
|
||||
// AccessLogCleanupIntervalHours specifies how often (in hours) to run the cleanup routine.
|
||||
// Defaults to 24 hours if not set or set to 0.
|
||||
AccessLogCleanupIntervalHours int
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
@@ -108,16 +106,6 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ServiceProxyController() proxy.Controller {
|
||||
return Create(s, func() proxy.Controller {
|
||||
controller, err := proxymanager.NewGRPCController(s.ReverseProxyGRPCServer(), s.Metrics().GetMeter())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create service proxy controller: %v", err)
|
||||
}
|
||||
return controller
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
||||
return Create(s, func() *server.AccountRequestBuffer {
|
||||
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
||||
|
||||
@@ -8,11 +8,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
@@ -101,11 +99,11 @@ func (s *BaseServer) AccountManager() account.Manager {
|
||||
return Create(s, func() account.Manager {
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account service: %v", err)
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
}
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
accountManager.SetServiceManager(s.ServiceManager())
|
||||
accountManager.SetServiceManager(s.ReverseProxyManager())
|
||||
})
|
||||
|
||||
return accountManager
|
||||
@@ -116,28 +114,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return Create(s, func() idp.Manager {
|
||||
var idpManager idp.Manager
|
||||
var err error
|
||||
// Use embedded IdP service if embedded Dex is configured and enabled.
|
||||
// Use embedded IdP manager if embedded Dex is configured and enabled.
|
||||
// Legacy IdpManager won't be used anymore even if configured.
|
||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create embedded IDP service: %v", err)
|
||||
log.Fatalf("failed to create embedded IDP manager: %v", err)
|
||||
}
|
||||
return idpManager
|
||||
}
|
||||
|
||||
// Fall back to external IdP service
|
||||
// Fall back to external IdP manager
|
||||
if s.Config.IdpManagerConfig != nil {
|
||||
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create IDP service: %v", err)
|
||||
log.Fatalf("failed to create IDP manager: %v", err)
|
||||
}
|
||||
}
|
||||
return idpManager
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
|
||||
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil
|
||||
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
@@ -164,7 +162,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
|
||||
|
||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||
return Create(s, func() resources.Manager {
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -192,25 +190,15 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ServiceManager() service.Manager {
|
||||
return Create(s, func() service.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ProxyManager() proxy.Manager {
|
||||
return Create(s, func() proxy.Manager {
|
||||
manager, err := proxymanager.NewManager(s.Store(), s.Metrics().GetMeter())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create proxy manager: %v", err)
|
||||
}
|
||||
return manager
|
||||
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
|
||||
return Create(s, func() reverseproxy.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return Create(s, func() *manager.Manager {
|
||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager())
|
||||
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
|
||||
// Eagerly create the gRPC server so that all AfterInit hooks are registered
|
||||
// before we iterate them. Lazy creation after the loop would miss hooks
|
||||
// registered during GRPCServer() construction (e.g., SetServiceManager).
|
||||
// registered during GRPCServer() construction (e.g., SetProxyManager).
|
||||
s.GRPCServer()
|
||||
|
||||
for _, fn := range s.afterInit {
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
internalStatus "github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// CreateExpose handles a peer request to create a new expose service.
|
||||
func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
exposeReq := &proto.ExposeServiceRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, exposeReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
reverseProxyMgr := s.getReverseProxyManager()
|
||||
if reverseProxyMgr == nil {
|
||||
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||
}
|
||||
|
||||
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
|
||||
NamePrefix: exposeReq.NamePrefix,
|
||||
Port: int(exposeReq.Port),
|
||||
Protocol: exposeProtocolToString(exposeReq.Protocol),
|
||||
Domain: exposeReq.Domain,
|
||||
Pin: exposeReq.Pin,
|
||||
Password: exposeReq.Password,
|
||||
UserGroups: exposeReq.UserGroups,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, mapExposeError(ctx, err)
|
||||
}
|
||||
|
||||
return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{
|
||||
ServiceName: created.ServiceName,
|
||||
ServiceUrl: created.ServiceURL,
|
||||
Domain: created.Domain,
|
||||
})
|
||||
}
|
||||
|
||||
// RenewExpose extends the TTL of an active expose session.
|
||||
func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
renewReq := &proto.RenewExposeRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, renewReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reverseProxyMgr := s.getReverseProxyManager()
|
||||
if reverseProxyMgr == nil {
|
||||
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||
}
|
||||
|
||||
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil {
|
||||
return nil, mapExposeError(ctx, err)
|
||||
}
|
||||
|
||||
return s.encryptResponse(peerKey, &proto.RenewExposeResponse{})
|
||||
}
|
||||
|
||||
// StopExpose terminates an active expose session.
|
||||
func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
stopReq := &proto.StopExposeRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, stopReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reverseProxyMgr := s.getReverseProxyManager()
|
||||
if reverseProxyMgr == nil {
|
||||
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||
}
|
||||
|
||||
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil {
|
||||
return nil, mapExposeError(ctx, err)
|
||||
}
|
||||
|
||||
return s.encryptResponse(peerKey, &proto.StopExposeResponse{})
|
||||
}
|
||||
|
||||
func mapExposeError(ctx context.Context, err error) error {
|
||||
s, ok := internalStatus.FromError(err)
|
||||
if !ok {
|
||||
log.WithContext(ctx).Errorf("expose service error: %v", err)
|
||||
return status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
switch s.Type() {
|
||||
case internalStatus.InvalidArgument:
|
||||
return status.Errorf(codes.InvalidArgument, "%s", s.Message)
|
||||
case internalStatus.PermissionDenied:
|
||||
return status.Errorf(codes.PermissionDenied, "%s", s.Message)
|
||||
case internalStatus.NotFound:
|
||||
return status.Errorf(codes.NotFound, "%s", s.Message)
|
||||
case internalStatus.AlreadyExists:
|
||||
return status.Errorf(codes.AlreadyExists, "%s", s.Message)
|
||||
case internalStatus.PreconditionFailed:
|
||||
return status.Errorf(codes.ResourceExhausted, "%s", s.Message)
|
||||
default:
|
||||
log.WithContext(ctx).Errorf("expose service error: %v", err)
|
||||
return status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) encryptResponse(peerKey wgtypes.Key, msg pb.Message) (*proto.EncryptedMessage, error) {
|
||||
wgKey, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, wgKey, msg)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "encrypt response")
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key) (string, *nbpeer.Peer, error) {
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
return "", nil, status.Errorf(codes.Internal, "lookup account for peer")
|
||||
}
|
||||
|
||||
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
|
||||
if err != nil {
|
||||
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
|
||||
return accountID, peer, nil
|
||||
}
|
||||
|
||||
func (s *Server) getReverseProxyManager() rpservice.Manager {
|
||||
s.reverseProxyMu.RLock()
|
||||
defer s.reverseProxyMu.RUnlock()
|
||||
return s.reverseProxyManager
|
||||
}
|
||||
|
||||
// SetReverseProxyManager sets the reverse proxy manager on the server.
|
||||
func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
|
||||
s.reverseProxyMu.Lock()
|
||||
defer s.reverseProxyMu.Unlock()
|
||||
s.reverseProxyManager = mgr
|
||||
}
|
||||
|
||||
func exposeProtocolToString(p proto.ExposeProtocol) string {
|
||||
switch p {
|
||||
case proto.ExposeProtocol_EXPOSE_HTTP:
|
||||
return "http"
|
||||
case proto.ExposeProtocol_EXPOSE_HTTPS:
|
||||
return "https"
|
||||
default:
|
||||
return "http"
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,28 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/cache"
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
)
|
||||
|
||||
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
||||
// for proxy-to-management RPC authentication. Tokens are generated when
|
||||
// a service is created and must be used exactly once by the proxy
|
||||
// to authenticate a subsequent RPC call.
|
||||
type OneTimeTokenStore struct {
|
||||
tokens map[string]*tokenMetadata
|
||||
mu sync.RWMutex
|
||||
cleanup *time.Ticker
|
||||
cleanupDone chan struct{}
|
||||
}
|
||||
|
||||
// tokenMetadata stores information about a one-time token
|
||||
type tokenMetadata struct {
|
||||
ServiceID string
|
||||
AccountID string
|
||||
@@ -25,24 +30,20 @@ type tokenMetadata struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
|
||||
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||
type OneTimeTokenStore struct {
|
||||
cache *cache.Cache[string]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewOneTimeTokenStore creates a token store with automatic backend selection
|
||||
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
|
||||
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
||||
// of expired tokens. The cleanupInterval determines how often expired
|
||||
// tokens are removed from memory.
|
||||
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||
store := &OneTimeTokenStore{
|
||||
tokens: make(map[string]*tokenMetadata),
|
||||
cleanup: time.NewTicker(cleanupInterval),
|
||||
cleanupDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
return &OneTimeTokenStore{
|
||||
cache: cache.New[string](cacheStore),
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
// Start background cleanup goroutine
|
||||
go store.cleanupExpired()
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// GenerateToken creates a new cryptographically secure one-time token
|
||||
@@ -51,30 +52,25 @@ func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.
|
||||
//
|
||||
// Returns the generated token string or an error if random generation fails.
|
||||
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
||||
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||
}
|
||||
|
||||
// Encode as URL-safe base64 for easy transmission in gRPC
|
||||
token := base64.URLEncoding.EncodeToString(randomBytes)
|
||||
hashedToken := hashToken(token)
|
||||
|
||||
metadata := &tokenMetadata{
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.tokens[token] = &tokenMetadata{
|
||||
ServiceID: serviceID,
|
||||
AccountID: accountID,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
metadataJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
|
||||
}
|
||||
|
||||
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
|
||||
return "", fmt.Errorf("failed to store token: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
||||
serviceID, accountID, ttl)
|
||||
|
||||
@@ -92,45 +88,80 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
|
||||
// - Account ID doesn't match
|
||||
// - Reverse proxy ID doesn't match
|
||||
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||
hashedToken := hashToken(token)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
|
||||
if err != nil {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
|
||||
metadata, exists := s.tokens[token]
|
||||
if !exists {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
return fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
metadata := &tokenMetadata{}
|
||||
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
|
||||
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||
return fmt.Errorf("invalid token metadata")
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
if time.Now().After(metadata.ExpiresAt) {
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
|
||||
delete(s.tokens, token)
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
// Validate account ID using constant-time comparison (prevents timing attacks)
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
||||
metadata.AccountID, accountID)
|
||||
return fmt.Errorf("account ID mismatch")
|
||||
}
|
||||
|
||||
// Validate service ID using constant-time comparison
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
||||
metadata.ServiceID, serviceID)
|
||||
return fmt.Errorf("service ID mismatch")
|
||||
}
|
||||
|
||||
if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
|
||||
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||
}
|
||||
// Delete token immediately to enforce single-use
|
||||
delete(s.tokens, token)
|
||||
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s",
|
||||
serviceID, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
||||
func (s *OneTimeTokenStore) cleanupExpired() {
|
||||
for {
|
||||
select {
|
||||
case <-s.cleanup.C:
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
removed := 0
|
||||
for token, metadata := range s.tokens {
|
||||
if now.After(metadata.ExpiresAt) {
|
||||
delete(s.tokens, token)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
if removed > 0 {
|
||||
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
case <-s.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup goroutine and releases resources
|
||||
func (s *OneTimeTokenStore) Close() {
|
||||
s.cleanup.Stop()
|
||||
close(s.cleanupDone)
|
||||
}
|
||||
|
||||
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
||||
func (s *OneTimeTokenStore) GetTokenCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.tokens)
|
||||
}
|
||||
|
||||
@@ -24,9 +24,8 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
@@ -59,17 +58,17 @@ type ProxyServiceServer struct {
|
||||
// Map of connected proxies: proxy_id -> proxy connection
|
||||
connectedProxies sync.Map
|
||||
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
|
||||
// Channel for broadcasting reverse proxy updates to all proxies
|
||||
updatesChan chan *proto.ProxyMapping
|
||||
|
||||
// Manager for access logs
|
||||
accessLogManager accesslogs.Manager
|
||||
|
||||
// Manager for reverse proxy operations
|
||||
serviceManager rpservice.Manager
|
||||
|
||||
// ProxyController for service updates and cluster management
|
||||
proxyController proxy.Controller
|
||||
|
||||
// Manager for proxy connections
|
||||
proxyManager proxy.Manager
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
|
||||
// Manager for peers
|
||||
peersManager peers.Manager
|
||||
@@ -102,25 +101,24 @@ type proxyConnection struct {
|
||||
proxyID string
|
||||
address string
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
sendChan chan *proto.ProxyMapping
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
accessLogManager: accessLogMgr,
|
||||
oidcConfig: oidcConfig,
|
||||
tokenStore: tokenStore,
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
pkceCleanupCancel: cancel,
|
||||
}
|
||||
go s.cleanupPKCEVerifiers(ctx)
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -144,33 +142,13 @@ func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
|
||||
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops background goroutines.
|
||||
func (s *ProxyServiceServer) Close() {
|
||||
s.pkceCleanupCancel()
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||
s.serviceManager = manager
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
|
||||
s.proxyController = proxyController
|
||||
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
|
||||
s.reverseProxyManager = manager
|
||||
}
|
||||
|
||||
// GetMappingUpdate handles the control stream with proxy clients
|
||||
@@ -199,21 +177,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
proxyID: proxyID,
|
||||
address: proxyAddress,
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
sendChan: make(chan *proto.ProxyMapping, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
|
||||
}
|
||||
|
||||
s.addToCluster(conn.address, proxyID)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"address": proxyAddress,
|
||||
@@ -221,15 +191,8 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
s.removeFromCluster(conn.address, proxyID)
|
||||
cancel()
|
||||
log.Infof("Proxy %s disconnected", proxyID)
|
||||
}()
|
||||
@@ -241,9 +204,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
// Start heartbeat goroutine
|
||||
go s.heartbeat(connCtx, proxyID)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
@@ -252,27 +212,10 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||
// Only services matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
@@ -281,7 +224,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
|
||||
var filtered []*rpservice.Service
|
||||
var filtered []*reverseproxy.Service
|
||||
for _, service := range services {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
@@ -316,7 +259,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
service.ToProtoMapping(
|
||||
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||
token,
|
||||
s.GetOIDCValidationConfig(),
|
||||
),
|
||||
@@ -345,7 +288,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
for {
|
||||
select {
|
||||
case msg := <-conn.sendChan:
|
||||
if err := conn.stream.Send(msg); err != nil {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
@@ -396,7 +339,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
|
||||
// Management should call this when services are created/updated/removed.
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
|
||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||
conn := value.(*proxyConnection)
|
||||
@@ -406,7 +349,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||
}
|
||||
@@ -450,75 +393,81 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
|
||||
return urls
|
||||
}
|
||||
|
||||
// addToCluster registers a proxy in a cluster.
|
||||
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
|
||||
// removeFromCluster removes a proxy from a cluster.
|
||||
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
|
||||
proxySet.(*sync.Map).Delete(proxyID)
|
||||
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
|
||||
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
|
||||
updateResponse := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{update},
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
|
||||
if clusterAddr == "" {
|
||||
s.SendServiceUpdate(updateResponse)
|
||||
s.SendServiceUpdate(update)
|
||||
return
|
||||
}
|
||||
|
||||
if s.proxyController == nil {
|
||||
log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
|
||||
return
|
||||
}
|
||||
|
||||
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
|
||||
if len(proxyIDs) == 0 {
|
||||
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||
proxySet, ok := s.clusterProxies.Load(clusterAddr)
|
||||
if !ok {
|
||||
log.Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending service update to cluster %s", clusterAddr)
|
||||
for _, proxyID := range proxyIDs {
|
||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||
proxyID := key.(string)
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
msg := s.perProxyMessage(update, proxyID)
|
||||
if msg == nil {
|
||||
continue
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
// create/update operations. For delete operations the original message is
|
||||
// returned unchanged because proxies do not need to authenticate for removal.
|
||||
// Returns nil if token generation fails (the proxy should be skipped).
|
||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
||||
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, mapping := range update.Mapping {
|
||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||
resp = append(resp, mapping)
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := shallowCloneMapping(mapping)
|
||||
msg.AuthToken = token
|
||||
resp = append(resp, msg)
|
||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping {
|
||||
if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" {
|
||||
return update
|
||||
}
|
||||
|
||||
return &proto.GetMappingUpdateResponse{
|
||||
Mapping: resp,
|
||||
token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := shallowCloneMapping(update)
|
||||
msg.AuthToken = token
|
||||
return msg
|
||||
}
|
||||
|
||||
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
|
||||
@@ -537,8 +486,35 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableClusters returns information about all connected proxy clusters.
|
||||
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
|
||||
clusterCounts := make(map[string]int)
|
||||
s.clusterProxies.Range(func(key, value interface{}) bool {
|
||||
clusterAddr := key.(string)
|
||||
proxySet := value.(*sync.Map)
|
||||
count := 0
|
||||
proxySet.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
clusterCounts[clusterAddr] = count
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
clusters := make([]ClusterInfo, 0, len(clusterCounts))
|
||||
for addr, count := range clusterCounts {
|
||||
clusters = append(clusters, ClusterInfo{
|
||||
Address: addr,
|
||||
ConnectedProxies: count,
|
||||
})
|
||||
}
|
||||
return clusters
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
|
||||
@@ -557,7 +533,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
|
||||
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) {
|
||||
switch v := req.GetRequest().(type) {
|
||||
case *proto.AuthenticateRequest_Pin:
|
||||
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
|
||||
@@ -568,7 +544,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
|
||||
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
@@ -582,7 +558,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
|
||||
return true, "pin-user", proxyauth.MethodPIN
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
||||
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
@@ -604,7 +580,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
if !authenticated || service.SessionPrivateKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
@@ -644,7 +620,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
}
|
||||
|
||||
if certificateIssued {
|
||||
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
||||
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
|
||||
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
|
||||
}
|
||||
@@ -656,7 +632,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
|
||||
internalStatus := protoStatusToInternal(protoStatus)
|
||||
|
||||
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to update service status")
|
||||
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
|
||||
}
|
||||
@@ -671,22 +647,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
}
|
||||
|
||||
// protoStatusToInternal maps proto status to internal status
|
||||
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus {
|
||||
switch protoStatus {
|
||||
case proto.ProxyStatus_PROXY_STATUS_PENDING:
|
||||
return rpservice.StatusPending
|
||||
return reverseproxy.StatusPending
|
||||
case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
|
||||
return rpservice.StatusActive
|
||||
return reverseproxy.StatusActive
|
||||
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
|
||||
return rpservice.StatusTunnelNotCreated
|
||||
return reverseproxy.StatusTunnelNotCreated
|
||||
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
|
||||
return rpservice.StatusCertificatePending
|
||||
return reverseproxy.StatusCertificatePending
|
||||
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
|
||||
return rpservice.StatusCertificateFailed
|
||||
return reverseproxy.StatusCertificateFailed
|
||||
case proto.ProxyStatus_PROXY_STATUS_ERROR:
|
||||
return rpservice.StatusError
|
||||
return reverseproxy.StatusError
|
||||
default:
|
||||
return rpservice.StatusError
|
||||
return reverseproxy.StatusError
|
||||
}
|
||||
}
|
||||
|
||||
@@ -751,7 +727,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||
}
|
||||
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
|
||||
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
|
||||
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
|
||||
@@ -814,8 +790,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
|
||||
|
||||
// GetOIDCValidationConfig returns the OIDC configuration for token validation
|
||||
// in the format needed by ToProtoMapping.
|
||||
func (s *ProxyServiceServer) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
||||
return proxy.OIDCValidationConfig{
|
||||
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig {
|
||||
return reverseproxy.OIDCValidationConfig{
|
||||
Issuer: s.oidcConfig.Issuer,
|
||||
Audiences: []string{s.oidcConfig.Audience},
|
||||
KeysLocation: s.oidcConfig.KeysLocation,
|
||||
@@ -874,12 +850,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
// Find the service by domain to get its signing key
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
var service *rpservice.Service
|
||||
var service *reverseproxy.Service
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
service = svc
|
||||
@@ -945,8 +921,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
|
||||
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||
services, err := s.serviceManager.GetAccountServices(ctx, accountID)
|
||||
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account services: %w", err)
|
||||
}
|
||||
@@ -1067,8 +1043,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
|
||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
@@ -1082,7 +1058,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
|
||||
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
|
||||
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -71,14 +71,12 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// token, err := interceptor.validateProxyToken(ctx)
|
||||
// if err != nil {
|
||||
// // Log auth failures explicitly; gRPC doesn't log these by default.
|
||||
// log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
token := &types.ProxyAccessToken{ID: "dummy"}
|
||||
token, err := interceptor.validateProxyToken(ctx)
|
||||
if err != nil {
|
||||
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
|
||||
return handler(ctx, req)
|
||||
@@ -89,13 +87,12 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
|
||||
return handler(srv, ss)
|
||||
}
|
||||
|
||||
// token, err := interceptor.validateProxyToken(ss.Context())
|
||||
// if err != nil {
|
||||
// // Log auth failures explicitly; gRPC doesn't log these by default.
|
||||
// log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
|
||||
// return err
|
||||
// }
|
||||
token := &types.ProxyAccessToken{ID: "dummy"} // TODO: Implement token validation for streaming methods.
|
||||
token, err := interceptor.validateProxyToken(ss.Context())
|
||||
if err != nil {
|
||||
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
|
||||
wrapped := &wrappedServerStream{
|
||||
|
||||
@@ -8,44 +8,40 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type mockReverseProxyManager struct {
|
||||
proxiesByAccount map[string][]*service.Service
|
||||
proxiesByAccount map[string][]*reverseproxy.Service
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.proxiesByAccount[accountID], nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
return []*service.Service{}, nil
|
||||
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
return []*reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
|
||||
return &service.Service{}, nil
|
||||
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
|
||||
return &service.Service{}, nil
|
||||
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
|
||||
return &service.Service{}, nil
|
||||
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
|
||||
@@ -56,7 +52,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
|
||||
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -68,28 +64,14 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
|
||||
return &service.Service{}, nil
|
||||
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
return &service.ExposeServiceResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
type mockUsersManager struct {
|
||||
users map[string]*types.User
|
||||
err error
|
||||
@@ -111,7 +93,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name string
|
||||
domain string
|
||||
userID string
|
||||
proxiesByAccount map[string][]*service.Service
|
||||
proxiesByAccount map[string][]*reverseproxy.Service
|
||||
users map[string]*types.User
|
||||
proxyErr error
|
||||
userErr error
|
||||
@@ -122,7 +104,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "user not found",
|
||||
domain: "app.example.com",
|
||||
userID: "unknown-user",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||
},
|
||||
users: map[string]*types.User{},
|
||||
@@ -133,7 +115,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "proxy not found in user's account",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*service.Service{},
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
},
|
||||
@@ -144,7 +126,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "proxy exists in different account - not accessible",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
@@ -157,8 +139,8 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "no bearer auth configured - same account allows access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
@@ -169,12 +151,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "bearer auth disabled - same account allows access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{Enabled: false},
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
|
||||
},
|
||||
}},
|
||||
},
|
||||
@@ -187,12 +169,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "bearer auth enabled but no groups configured - same account allows access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{},
|
||||
},
|
||||
@@ -208,12 +190,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "user not in allowed groups",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"group1", "group2"},
|
||||
},
|
||||
@@ -230,12 +212,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "user in one of the allowed groups - allow access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"group1", "group2"},
|
||||
},
|
||||
@@ -251,12 +233,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "user in all allowed groups - allow access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"group1", "group2"},
|
||||
},
|
||||
@@ -284,10 +266,10 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "multiple proxies in account - finds correct one",
|
||||
domain: "app2.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {
|
||||
{Domain: "app1.example.com", AccountID: "account1"},
|
||||
{Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
|
||||
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
|
||||
{Domain: "app3.example.com", AccountID: "account1"},
|
||||
},
|
||||
},
|
||||
@@ -301,7 +283,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &ProxyServiceServer{
|
||||
serviceManager: &mockReverseProxyManager{
|
||||
reverseProxyManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: tt.proxiesByAccount,
|
||||
err: tt.proxyErr,
|
||||
},
|
||||
@@ -328,7 +310,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
name string
|
||||
accountID string
|
||||
domain string
|
||||
proxiesByAccount map[string][]*service.Service
|
||||
proxiesByAccount map[string][]*reverseproxy.Service
|
||||
err error
|
||||
expectProxy bool
|
||||
expectErr bool
|
||||
@@ -337,7 +319,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
name: "proxy found",
|
||||
accountID: "account1",
|
||||
domain: "app.example.com",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {
|
||||
{Domain: "other.example.com", AccountID: "account1"},
|
||||
{Domain: "app.example.com", AccountID: "account1"},
|
||||
@@ -350,7 +332,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
name: "proxy not found in account",
|
||||
accountID: "account1",
|
||||
domain: "unknown.example.com",
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||
},
|
||||
expectProxy: false,
|
||||
@@ -360,7 +342,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
name: "empty proxy list for account",
|
||||
accountID: "account1",
|
||||
domain: "app.example.com",
|
||||
proxiesByAccount: map[string][]*service.Service{},
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||
expectProxy: false,
|
||||
expectErr: true,
|
||||
},
|
||||
@@ -378,7 +360,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &ProxyServiceServer{
|
||||
serviceManager: &mockReverseProxyManager{
|
||||
reverseProxyManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: tt.proxiesByAccount,
|
||||
err: tt.err,
|
||||
},
|
||||
|
||||
@@ -1,77 +1,23 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"sync"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type testProxyController struct {
|
||||
mu sync.Mutex
|
||||
clusterProxies map[string]map[string]struct{}
|
||||
}
|
||||
|
||||
func newTestProxyController() *testProxyController {
|
||||
return &testProxyController{
|
||||
clusterProxies: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
|
||||
}
|
||||
|
||||
func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
||||
return proxy.OIDCValidationConfig{}
|
||||
}
|
||||
|
||||
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if _, ok := c.clusterProxies[clusterAddr]; !ok {
|
||||
c.clusterProxies[clusterAddr] = make(map[string]struct{})
|
||||
}
|
||||
c.clusterProxies[clusterAddr][proxyID] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if proxies, ok := c.clusterProxies[clusterAddr]; ok {
|
||||
delete(proxies, proxyID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
proxies, ok := c.clusterProxies[clusterAddr]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(proxies))
|
||||
for id := range proxies {
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||
// and returns the channel where messages will be received.
|
||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
|
||||
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
||||
ch := make(chan *proto.ProxyMapping, 10)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: clusterAddr,
|
||||
@@ -79,12 +25,13 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
||||
}
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
|
||||
_ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
|
||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
return msg
|
||||
@@ -94,24 +41,24 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
tokenStore: tokenStore,
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
const numProxies = 3
|
||||
|
||||
channels := make([]chan *proto.GetMappingUpdateResponse, numProxies)
|
||||
channels := make([]chan *proto.ProxyMapping, numProxies)
|
||||
for i := range numProxies {
|
||||
id := "proxy-" + string(rune('a'+i))
|
||||
channels[i] = registerFakeProxy(s, id, cluster)
|
||||
}
|
||||
|
||||
mapping := &proto.ProxyMapping{
|
||||
update := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-1",
|
||||
AccountId: "account-1",
|
||||
@@ -121,16 +68,14 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
|
||||
s.SendServiceUpdateToCluster(update, cluster)
|
||||
|
||||
tokens := make([]string, numProxies)
|
||||
for i, ch := range channels {
|
||||
resp := drainChannel(ch)
|
||||
require.NotNil(t, resp, "proxy %d should receive a message", i)
|
||||
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
|
||||
msg := resp.Mapping[0]
|
||||
assert.Equal(t, mapping.Domain, msg.Domain)
|
||||
assert.Equal(t, mapping.Id, msg.Id)
|
||||
msg := drainChannel(ch)
|
||||
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
||||
assert.Equal(t, update.Domain, msg.Domain)
|
||||
assert.Equal(t, update.Id, msg.Id)
|
||||
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
||||
tokens[i] = msg.AuthToken
|
||||
}
|
||||
@@ -151,74 +96,66 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
tokenStore: tokenStore,
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
||||
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
||||
|
||||
mapping := &proto.ProxyMapping{
|
||||
update := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||
Id: "service-1",
|
||||
AccountId: "account-1",
|
||||
Domain: "test.example.com",
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
|
||||
s.SendServiceUpdateToCluster(update, cluster)
|
||||
|
||||
resp1 := drainChannel(ch1)
|
||||
resp2 := drainChannel(ch2)
|
||||
require.NotNil(t, resp1)
|
||||
require.NotNil(t, resp2)
|
||||
require.Len(t, resp1.Mapping, 1)
|
||||
require.Len(t, resp2.Mapping, 1)
|
||||
msg1 := drainChannel(ch1)
|
||||
msg2 := drainChannel(ch2)
|
||||
require.NotNil(t, msg1)
|
||||
require.NotNil(t, msg2)
|
||||
|
||||
// Delete operations should not generate tokens
|
||||
assert.Empty(t, resp1.Mapping[0].AuthToken)
|
||||
assert.Empty(t, resp2.Mapping[0].AuthToken)
|
||||
assert.Empty(t, msg1.AuthToken)
|
||||
assert.Empty(t, msg2.AuthToken)
|
||||
|
||||
// No tokens should have been created
|
||||
assert.Equal(t, 0, tokenStore.GetTokenCount())
|
||||
}
|
||||
|
||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
tokenStore: tokenStore,
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
||||
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
||||
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
||||
|
||||
mapping := &proto.ProxyMapping{
|
||||
update := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-1",
|
||||
AccountId: "account-1",
|
||||
Domain: "test.example.com",
|
||||
}
|
||||
|
||||
update := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{mapping},
|
||||
}
|
||||
|
||||
s.SendServiceUpdate(update)
|
||||
|
||||
resp1 := drainChannel(ch1)
|
||||
resp2 := drainChannel(ch2)
|
||||
require.NotNil(t, resp1)
|
||||
require.NotNil(t, resp2)
|
||||
require.Len(t, resp1.Mapping, 1)
|
||||
require.Len(t, resp2.Mapping, 1)
|
||||
|
||||
msg1 := resp1.Mapping[0]
|
||||
msg2 := resp2.Mapping[0]
|
||||
msg1 := drainChannel(ch1)
|
||||
msg2 := drainChannel(ch2)
|
||||
require.NotNil(t, msg1)
|
||||
require.NotNil(t, msg2)
|
||||
|
||||
assert.NotEmpty(t, msg1.AuthToken)
|
||||
assert.NotEmpty(t, msg2.AuthToken)
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
@@ -81,9 +80,6 @@ type Server struct {
|
||||
syncSem atomic.Int32
|
||||
syncLimEnabled bool
|
||||
syncLim int32
|
||||
|
||||
reverseProxyManager rpservice.Manager
|
||||
reverseProxyMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -34,15 +34,11 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
serviceManager := &testValidateSessionServiceManager{store: testStore}
|
||||
proxyManager := &testValidateSessionProxyManager{store: testStore}
|
||||
usersManager := &testValidateSessionUsersManager{store: testStore}
|
||||
proxyManager := &testValidateSessionProxyManager{}
|
||||
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
|
||||
proxyService.SetProxyManager(proxyManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
@@ -58,7 +54,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
|
||||
|
||||
pubKey, privKey := generateSessionKeyPair(t)
|
||||
|
||||
testProxy := &service.Service{
|
||||
testProxy := &reverseproxy.Service{
|
||||
ID: "testProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Test Proxy",
|
||||
@@ -66,15 +62,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
|
||||
Enabled: true,
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||
|
||||
restrictedProxy := &service.Service{
|
||||
restrictedProxy := &reverseproxy.Service{
|
||||
ID: "restrictedProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Restricted Proxy",
|
||||
@@ -82,8 +78,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
|
||||
Enabled: true,
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"allowedGroupId"},
|
||||
},
|
||||
@@ -200,7 +196,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.Valid, "Unknown proxy should be denied")
|
||||
assert.Equal(t, "service_not_found", resp.DeniedReason)
|
||||
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
|
||||
}
|
||||
|
||||
func TestValidateSession_InvalidToken(t *testing.T) {
|
||||
@@ -243,102 +239,62 @@ func TestValidateSession_MissingToken(t *testing.T) {
|
||||
assert.Contains(t, resp.DeniedReason, "missing")
|
||||
}
|
||||
|
||||
type testValidateSessionServiceManager struct {
|
||||
type testValidateSessionProxyManager struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
|
||||
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
|
||||
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
|
||||
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
|
||||
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testValidateSessionUsersManager struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -83,9 +83,9 @@ type DefaultAccountManager struct {
|
||||
|
||||
requestBuffer *AccountRequestBuffer
|
||||
|
||||
proxyController port_forwarding.Controller
|
||||
settingsManager settings.Manager
|
||||
serviceManager service.Manager
|
||||
proxyController port_forwarding.Controller
|
||||
settingsManager settings.Manager
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
|
||||
// config contains the management server configuration
|
||||
config *nbconfig.Config
|
||||
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
|
||||
|
||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||
|
||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
|
||||
am.serviceManager = serviceManager
|
||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||
am.reverseProxyManager = serviceManager
|
||||
}
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
@@ -376,7 +376,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -395,7 +394,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||
}
|
||||
if reloadReverseProxy {
|
||||
if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
@@ -493,21 +492,6 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
oldEnabled := oldSettings.PeerExposeEnabled
|
||||
newEnabled := newSettings.PeerExposeEnabled
|
||||
|
||||
if oldEnabled == newEnabled {
|
||||
return
|
||||
}
|
||||
|
||||
event := activity.AccountPeerExposeEnabled
|
||||
if !newEnabled {
|
||||
event = activity.AccountPeerExposeDisabled
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
||||
if newSettings.PeerInactivityExpirationEnabled {
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
@@ -730,11 +714,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
for _, otherUser := range account.Users {
|
||||
if otherUser.Id == userID {
|
||||
continue
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package account
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -63,11 +61,11 @@ type Manager interface {
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
@@ -142,5 +140,5 @@ type Manager interface {
|
||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
SetServiceManager(serviceManager service.Manager)
|
||||
SetServiceManager(serviceManager reverseproxy.Manager)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -28,13 +27,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
@@ -1806,12 +1802,12 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Address: "172.12.6.1/24",
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
Services: []*reverseproxy.Service{
|
||||
{
|
||||
ID: "service1",
|
||||
Name: "test-service",
|
||||
AccountID: "account1",
|
||||
Targets: []*service.Target{},
|
||||
Targets: []*reverseproxy.Target{},
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
@@ -3116,12 +3112,6 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
proxyManager := proxy.NewMockManager(ctrl)
|
||||
proxyManager.EXPECT().
|
||||
CleanupStale(gomock.Any(), gomock.Any()).
|
||||
Return(nil).
|
||||
AnyTimes()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
@@ -3132,12 +3122,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
|
||||
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
|
||||
@@ -208,18 +208,6 @@ const (
|
||||
ServiceUpdated Activity = 109
|
||||
ServiceDeleted Activity = 110
|
||||
|
||||
// PeerServiceExposed indicates that a peer exposed a service via the reverse proxy
|
||||
PeerServiceExposed Activity = 111
|
||||
// PeerServiceUnexposed indicates that a peer-exposed service was removed
|
||||
PeerServiceUnexposed Activity = 112
|
||||
// PeerServiceExposeExpired indicates that a peer-exposed service was removed due to TTL expiration
|
||||
PeerServiceExposeExpired Activity = 113
|
||||
|
||||
// AccountPeerExposeEnabled indicates that a user enabled peer expose for the account
|
||||
AccountPeerExposeEnabled Activity = 114
|
||||
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
|
||||
AccountPeerExposeDisabled Activity = 115
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -357,13 +345,6 @@ var activityMap = map[Activity]Code{
|
||||
ServiceCreated: {"Service created", "service.create"},
|
||||
ServiceUpdated: {"Service updated", "service.update"},
|
||||
ServiceDeleted: {"Service deleted", "service.delete"},
|
||||
|
||||
PeerServiceExposed: {"Peer exposed service", "service.peer.expose"},
|
||||
PeerServiceUnexposed: {"Peer unexposed service", "service.peer.unexpose"},
|
||||
PeerServiceExposeExpired: {"Peer exposed service expired", "service.peer.expose.expire"},
|
||||
|
||||
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
|
||||
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -249,15 +249,7 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
|
||||
|
||||
switch storeEngine {
|
||||
case types.SqliteStoreEngine:
|
||||
dbFile := eventSinkDB
|
||||
if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" {
|
||||
dbFile = envFile
|
||||
}
|
||||
connStr := dbFile
|
||||
if !filepath.IsAbs(dbFile) {
|
||||
connStr = filepath.Join(dataDir, dbFile)
|
||||
}
|
||||
dialector = sqlite.Open(connStr)
|
||||
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
|
||||
case types.PostgresStoreEngine:
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
|
||||
4
management/server/cache/store_test.go
vendored
4
management/server/cache/store_test.go
vendored
@@ -7,6 +7,8 @@ import (
|
||||
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
|
||||
testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
@@ -48,7 +50,7 @@ func TestRedisStoreConnectionFailure(t *testing.T) {
|
||||
|
||||
func TestRedisStoreConnectionSuccess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
redisContainer, err := testcontainersredis.Run(ctx, "redis:7")
|
||||
redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7"))
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't start redis container: %s", err)
|
||||
}
|
||||
|
||||
@@ -425,11 +425,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*types.Group
|
||||
|
||||
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
@@ -438,7 +433,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
continue
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
|
||||
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
@@ -626,7 +621,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
|
||||
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == types.GroupIssuedIntegration {
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
@@ -646,10 +641,6 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
|
||||
return &GroupLinkError{"network resource", group.Resources[0].ID}
|
||||
}
|
||||
|
||||
if slices.Contains(flowGroups, group.ID) {
|
||||
return &GroupLinkError{"settings", "traffic event logging"}
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user