Compare commits

...

47 Commits

Author SHA1 Message Date
Pascal Fischer
5585adce18 [management] add activity events for domains (#5548)
* add activity events for domains

* fix test

* update activity codes

* update activity codes
2026-03-09 19:04:04 +01:00
Pascal Fischer
f884299823 [proxy] refactor metrics and add usage logs (#5533)
* **New Features**
  * Access logs now include bytes_upload and bytes_download (API and schemas updated, fields required).
  * Certificate issuance duration is now recorded as a metric.

* **Refactor**
  * Metrics switched from Prometheus client to OpenTelemetry-backed meters; health endpoint now exposes OpenMetrics via OTLP exporter.

* **Tests**
  * Metric tests updated to use OpenTelemetry Prometheus exporter and MeterProvider.
2026-03-09 18:45:45 +01:00
Maycon Santos
15aa6bae1b [client] Fix exit node menu not refreshing on Windows (#5553)
* [client] Fix exit node menu not refreshing on Windows

TrayOpenedCh is not implemented in the systray library on Windows,
so exit nodes were never refreshed after the initial connect. Combined
with the management sync not having populated routes yet when the
Connected status fires, this caused the exit node menu to remain empty
permanently after disconnect/reconnect cycles.
Add a background poller on Windows that refreshes exit nodes while
connected, with fast initial polling to catch routes from management
sync followed by a steady 10s interval. On macOS/Linux, TrayOpenedCh
continues to handle refreshes on each tray open.
Also fix a data race on connectClient assignment in the server's connect()
method and add nil checks in CleanState/DeleteState to prevent panics
when connectClient is nil.

* Remove unused exitNodeIDs

* Remove unused exitNodeState struct
2026-03-09 18:39:11 +01:00
Pascal Fischer
11eb725ac8 [management] only count login request duration for successful logins (#5545) 2026-03-09 14:56:46 +01:00
Pascal Fischer
30c02ab78c [management] use the cache for the pkce state (#5516) 2026-03-09 12:23:06 +01:00
Zoltan Papp
3acd86e346 [client] "reset connection" error on wake from sleep (#5522)
Capture engine reference before actCancel() in cleanupConnection().

After actCancel(), the connectWithRetryRuns goroutine sets engine to nil,
causing connectClient.Stop() to skip shutdown. This allows the goroutine
to set ErrResetConnection on the shared state after Down() clears it,
causing the next Up() to fail.
2026-03-09 10:25:51 +01:00
Pascal Fischer
5c20f13c48 [management] fix domain uniqueness (#5529) 2026-03-07 10:46:37 +01:00
Pascal Fischer
e6587b071d [management] use realip for proxy registration (#5525) 2026-03-06 16:11:44 +01:00
Maycon Santos
85451ab4cd [management] Add stable domain resolution for combined server (#5515)
The combined server was using the hostname from exposedAddress for both
singleAccountModeDomain and dnsDomain, causing fresh installs to get
the wrong domain and existing installs to break if the config changed.
 Add resolveDomains() to BaseServer that reads domain from the store:
  - Fresh install (0 accounts): uses "netbird.selfhosted" default
  - Existing install: reads persisted domain from the account in DB
  - Store errors: falls back to default safely

The combined server opts in via AutoResolveDomains flag, while the
 standalone management server is unaffected.
2026-03-06 08:43:46 +01:00
Pascal Fischer
a7f3ba03eb [management] aggregate grpc metrics by accountID (#5486) 2026-03-05 22:10:45 +01:00
Maycon Santos
4f0a3a77ad [management] Avoid breaking single acc mode when switching domains (#5511)
* **Bug Fixes**
  * Fixed domain configuration handling in single account mode to properly retrieve and apply domain settings from account data.
  * Improved error handling when account data is unavailable with fallback to configured default domain.

* **Tests**
  * Added comprehensive test coverage for single account mode domain configuration scenarios, including edge cases for missing or unavailable account data.
2026-03-05 14:30:31 +01:00
Maycon Santos
44655ca9b5 [misc] add PR title validation workflow (#5503) 2026-03-05 11:43:18 +01:00
Viktor Liu
e601278117 [management,proxy] Add per-target options to reverse proxy (#5501) 2026-03-05 10:03:26 +01:00
Maycon Santos
8e7b016be2 [management] Replace in-memory expose tracker with SQL-backed operations (#5494)
The expose tracker used sync.Map for in-memory TTL tracking of active expose sessions, which broke and lost all sessions on restart.

Replace with SQL-backed operations that reuse the existing meta_last_renewed_at column:

- Add store methods: RenewEphemeralService, GetExpiredEphemeralServices, CountEphemeralServicesByPeer, EphemeralServiceExists
- Move duplicate/limit checks inside a transaction with row-level locking (SELECT ... FOR UPDATE) to prevent concurrent bypass
- Reaper re-checks expiry under row lock to avoid deleting a just-renewed service and prevent duplicate event emission 
- Add composite index on (source, source_peer) for efficient queries
- Batch-limit and column-select the reaper query to avoid DB/GC spikes
- Filter out malformed rows with empty source_peer
2026-03-04 18:15:13 +01:00
Maycon Santos
9e01ea7aae [misc] Add ISSUE_TEMPLATE configuration file (#5500)
Add issue template config file  with support and troubleshooting links
2026-03-04 14:30:54 +01:00
hbzhost
cfc7ec8bb9 [client] Fix SSH JWT auth failure with Azure Entra ID iat backdating (#5471)
Increase DefaultJWTMaxTokenAge from 5 to 10 minutes to accommodate
identity providers like Azure Entra ID that backdate the iat claim
by up to 5 minutes, causing tokens to be immediately rejected.

Fixes #5449

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-04 14:11:14 +01:00
Misha Bragin
b3bbc0e5c6 Fix embedded IdP metrics to count local and generic OIDC users (#5498) 2026-03-04 12:34:11 +02:00
Pascal Fischer
d7c8e37ff4 [management] Store connected proxies in DB (#5472)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-03-03 18:39:46 +01:00
Zoltan Papp
05b66e73bc [client] Fix deadlock in route peer status watcher (#5489)
Wrap peerStateUpdate send in a nested select to prevent goroutine
blocking when the consumer has exited, which could fill the
subscription buffer and deadlock the Status mutex.
2026-03-03 13:50:46 +01:00
Jeremie Deray
01ceedac89 [client] Fix profile config directory permissions (#5457)
* fix user profile dir perm

* fix fileExists

* revert return var change

* fix anti-pattern
2026-03-03 13:48:51 +01:00
Misha Bragin
403babd433 [self-hosted] specify sql file location of auth, activity and main store (#5487) 2026-03-03 12:53:16 +02:00
Maycon Santos
47133031e5 [client] fix: client/Dockerfile to reduce vulnerabilities (#5217)
Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-03-03 08:44:08 +01:00
Pascal Fischer
82da606886 [management] Add explicit target delete on service removal (#5420) 2026-03-02 18:25:44 +01:00
Viktor Liu
bbe5ae2145 [client] Flush buffer immediately to support gprc (#5469) 2026-03-02 15:17:08 +01:00
Viktor Liu
0b21498b39 [client] Fix close of closed channel panic in ConnectClient retry loop (#5470) 2026-03-02 10:07:53 +01:00
Viktor Liu
0ca59535f1 [management] Add reverse proxy services REST client (#5454) 2026-02-28 13:04:58 +08:00
Misha Bragin
59c77d0658 [self-hosted] support embedded IDP postgres db (#5443)
* Add postgres config for embedded idp

Entire-Checkpoint: 9ace190c1067

* Rename idpStore to authStore

Entire-Checkpoint: 73a896c79614

* Fix review notes

Entire-Checkpoint: 6556783c0df3

* Don't accept pq port = 0

Entire-Checkpoint: 80d45e37782f

* Optimize configs

Entire-Checkpoint: 80d45e37782f

* Fix lint issues

Entire-Checkpoint: 3eec968003d1

* Fail fast on combined postgres config

Entire-Checkpoint: b17839d3d8c6

* Simplify management config method

Entire-Checkpoint: 0f083effa20e
2026-02-27 14:52:54 +01:00
shuuri-labs
333e045099 Lower socket auto-discovery log from Info to Debug (#5463)
The discovery message was printing on every CLI invocation, which is
noisy for users on distros using the systemd template.
2026-02-26 17:51:38 +01:00
Zoltan Papp
c2c4d9d336 [client] Fix Server mutex held across waitForUp in Up() (#5460)
Up() acquired s.mutex with a deferred unlock, then called waitForUp()
while still holding the lock. waitForUp() blocks for up to 50 seconds
waiting on clientRunningChan/clientGiveUpChan, starving all concurrent
gRPC calls that require the same mutex (Status, ListProfiles, etc.).

Replace the deferred unlock with explicit s.mutex.Unlock() on every
early-return path and immediately before waitForUp(), matching the
pattern already used by the clientRunning==true branch.
2026-02-26 16:47:02 +01:00
Bethuel Mmbaga
9a6a72e88e [management] Fix user update permission validation (#5441) 2026-02-24 22:47:41 +03:00
Bethuel Mmbaga
afe6d9fca4 [management] Prevent deletion of groups linked to flow groups (#5439) 2026-02-24 21:19:43 +03:00
shuuri-labs
ef82905526 [client] Add non default socket file discovery (#5425)
- Automatic Unix daemon address discovery: if the default socket is missing, the client can find and use a single available socket.
- Client startup now resolves daemon addresses more robustly while preserving non-Unix behavior.
2026-02-24 17:02:06 +01:00
Zoltan Papp
d18747e846 [client] Exclude Flow domain from caching to prevent TLS failures (#5433)
* Exclude Flow domain from caching to prevent TLS failures due to stale records.

* Fix test
2026-02-24 16:48:38 +01:00
Maycon Santos
f341d69314 [management] Add custom domain counts and service metrics to self-hosted metrics (#5414) 2026-02-24 15:21:14 +01:00
Maycon Santos
327142837c [management] Refactor expose feature: move business logic from gRPC to manager (#5435)
Consolidate all expose business logic (validation, permission checks, TTL tracking, reaping) into the manager layer, making the gRPC layer a pure transport adapter that only handles proto conversion and authentication.

- Add ExposeServiceRequest/ExposeServiceResponse domain types with validation in the reverseproxy package
- Move expose tracker (TTL tracking, reaping, per-peer limits) from gRPC server into manager/expose_tracker.go
- Internalize tracking in CreateServiceFromPeer, RenewServiceFromPeer, and new StopServiceFromPeer so callers don't manage tracker state
- Untrack ephemeral services in DeleteService/DeleteAllServices to keep tracker in sync when services are deleted via API
- Simplify gRPC expose handlers to parse, auth, convert, delegate
- Remove tracker methods from Manager interface (internal detail)
2026-02-24 15:09:30 +01:00
Zoltan Papp
f8c0321aee [client] Simplify DNS logging by removing domain list from log output (#5396) 2026-02-24 10:35:45 +01:00
Zoltan Papp
89115ff76a [client] skip UAPI listener in netstack mode (#5397)
In netstack (proxy) mode, the process lacks permission to create
/var/run/wireguard, making the UAPI listener unnecessary and causing
a misleading error log. Introduce NewUSPConfigurerNoUAPI and use it
for the netstack device to avoid attempting to open the UAPI socket
entirely. Also consolidate UAPI error logging to a single call site.
2026-02-24 10:35:23 +01:00
Maycon Santos
63c83aa8d2 [client,management] Feature/client service expose (#5411)
CLI: new expose command to publish a local port with flags for PIN, password, user groups, custom domain, name prefix and protocol (HTTP default).
Management/API: create/renew/stop expose sessions (streamed status), automatic naming/domain, TTL renewals, background expiration, new management RPCs and client methods.
UI/API: account settings now include peer_expose_enabled and peer_expose_groups; new activity codes for peer expose events.
2026-02-24 10:02:16 +01:00
Zoltan Papp
37f025c966 Fix a race condition where a concurrent user-issued Up or Down command (#5418)
could interleave with a sleep/wake event causing out-of-order state
transitions. The mutex now covers the full duration of each handler
including the status check, the Up/Down call, and the flag update.

Note: if Up or Down commands are triggered in parallel with sleep/wake
events, the overall ordering of up/down/sleep/wake operations is still
not guaranteed beyond what the mutex provides within the handler itself.
2026-02-24 10:00:33 +01:00
Zoltan Papp
4a54f0d670 [Client] Remove connection semaphore (#5419)
* [Client] Remove connection semaphore

Remove the semaphore and the initial random sleep time (300ms) from the connectivity logic to speed up the initial connection time.

Note: Implement limiter logic that can prioritize router peers and keep the fast connection option for the first few peers.

* Remove unused function
2026-02-23 20:58:53 +01:00
Zoltan Papp
98890a29e3 [client] fix busy-loop in network monitor routing socket on macOS/BSD (#5424)
* [client] fix busy-loop in network monitor routing socket on macOS/BSD

After system wakeup, the AF_ROUTE socket created by Go's unix.Socket()
is non-blocking, causing unix.Read to return EAGAIN immediately and spin
at 100% CPU filling the log with thousands of warnings per second.

Replace the tight read loop with a unix.Select call that blocks until
the fd is readable, checking ctx cancellation on each 1-second timeout.
Fatal errors (EBADF, EINVAL) now return an error instead of looping.

* [client] add fd range validation in waitReadable to prevent out-of-bound errors
2026-02-23 20:58:27 +01:00
Pascal Fischer
9d123ec059 [proxy] add pre-shared key support (#5377) 2026-02-23 16:31:29 +01:00
Pascal Fischer
5d171f181a [proxy] Send proxy updates on account delete (#5375) 2026-02-23 16:08:28 +01:00
Vlad
22f878b3b7 [management] network map components assembling (#5193) 2026-02-23 15:34:35 +01:00
Misha Bragin
44ef1a18dd [self-hosted] add Embedded IdP metrics (#5407) 2026-02-22 11:58:35 +02:00
Misha Bragin
2b98dc4e52 [self-hosted] Support activity store engine in the combined server (#5406) 2026-02-22 11:58:17 +02:00
Zoltan Papp
2a26cb4567 [client] stop upstream retry loop immediately on context cancellation (#5403)
stop upstream retry loop immediately on context cancellation
2026-02-20 14:44:14 +01:00
156 changed files with 17530 additions and 3858 deletions

14
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,14 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

51
.github/workflows/pr-title-check.yml vendored Normal file
View File

@@ -0,0 +1,51 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.2
FROM alpine:3.23.3
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

194
client/cmd/expose.go Normal file
View File

@@ -0,0 +1,194 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
var (
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
)
var exposeCmd = &cobra.Command{
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: "netbird expose --with-password safe-pass 8080",
RunE: exposeFn,
}
func init() {
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)")
}
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
port, err := strconv.ParseUint(portStr, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port == 0 || port > 65535 {
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
}
if !isProtocolValid(exposeProtocol) {
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
}
if cmd.Flags().Changed("with-password") && exposePassword == "" {
return 0, fmt.Errorf("password cannot be empty")
}
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
return 0, fmt.Errorf("user groups cannot be empty")
}
return port, nil
}
func isProtocolValid(exposeProtocol string) bool {
return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https"
}
func exposeFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
log.Errorf("failed initializing log %v", err)
return err
}
cmd.Root().SilenceUsage = false
port, err := validateExposeFlags(cmd, args[0])
if err != nil {
return err
}
cmd.Root().SilenceUsage = true
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close daemon connection: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
protocol, err := toExposeProtocol(exposeProtocol)
if err != nil {
return err
}
stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{
Port: uint32(port),
Protocol: protocol,
Pin: exposePin,
Password: exposePassword,
UserGroups: exposeUserGroups,
Domain: exposeDomain,
NamePrefix: exposeNamePrefix,
})
if err != nil {
return fmt.Errorf("expose service: %w", err)
}
if err := handleExposeReady(cmd, stream, port); err != nil {
return err
}
return waitForExposeEvents(cmd, ctx, stream)
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
switch strings.ToLower(exposeProtocol) {
case "http":
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case "https":
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %w", err)
}
switch e := event.Event.(type) {
case *proto.ExposeServiceEvent_Ready:
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", e.Ready.ServiceName)
cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl)
cmd.Printf(" Domain: %s\n", e.Ready.Domain)
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Port: %d\n", port)
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
return nil
default:
return fmt.Errorf("unexpected expose event: %T", event.Event)
}
}
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
for {
_, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.Println("\nService stopped.")
//nolint:nilerr
return nil
}
if errors.Is(err, io.EOF) {
return fmt.Errorf("connection to daemon closed unexpectedly")
}
return fmt.Errorf("stream error: %w", err)
}
}
}

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -80,6 +81,15 @@ var (
Short: "",
Long: "",
SilenceUsage: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(cmd.Root())
// Don't resolve for service commands — they create the socket, not connect to it.
if !isServiceCmd(cmd) {
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
}
return nil
},
}
)
@@ -144,6 +154,7 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd)
rootCmd.AddCommand(exposeCmd)
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -385,7 +396,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
}
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
@@ -398,3 +408,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
return conn, nil
}
// isServiceCmd returns true if cmd is the "service" command or a child of it.
func isServiceCmd(cmd *cobra.Command) bool {
for c := cmd; c != nil; c = c.Parent() {
if c.Name() == "service" {
return true
}
}
return false
}

View File

@@ -5,20 +5,18 @@ package configurer
import (
"net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/ipc"
)
func openUAPI(deviceName string) (net.Listener, error) {
uapiSock, err := ipc.UAPIOpen(deviceName)
if err != nil {
log.Errorf("failed to open uapi socket: %v", err)
return nil, err
}
listener, err := ipc.UAPIListen(deviceName, uapiSock)
if err != nil {
log.Errorf("failed to listen on uapi socket: %v", err)
_ = uapiSock.Close()
return nil, err
}

View File

@@ -54,6 +54,14 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
return wgCfg
}
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
return &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)

View File

@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
if cErr := tunIface.Close(); cErr != nil {

View File

@@ -331,8 +331,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected)
if runningChan != nil {
close(runningChan)
runningChan = nil
select {
case <-runningChan:
default:
close(runningChan)
}
}
<-engineCtx.Done()

View File

@@ -0,0 +1,60 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
var scanDir = "/var/run/netbird"
// setScanDir overrides the scan directory (used by tests).
func setScanDir(dir string) {
scanDir = dir
}
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
// mismatch between the netbird@.service template (which places the socket under
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
func ResolveUnixDaemonAddr(addr string) string {
if !strings.HasPrefix(addr, "unix://") {
return addr
}
sockPath := strings.TrimPrefix(addr, "unix://")
if _, err := os.Stat(sockPath); err == nil {
return addr
}
entries, err := os.ReadDir(scanDir)
if err != nil {
return addr
}
var found []string
for _, e := range entries {
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".sock") {
found = append(found, filepath.Join(scanDir, e.Name()))
}
}
switch len(found) {
case 1:
resolved := "unix://" + found[0]
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
return resolved
case 0:
return addr
default:
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
return addr
}
}

View File

@@ -0,0 +1,8 @@
//go:build windows || ios || android
package daemonaddr
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
func ResolveUnixDaemonAddr(addr string) string {
return addr
}

View File

@@ -0,0 +1,121 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"testing"
)
// createSockFile creates a regular file with a .sock extension.
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
// sufficient and avoids Unix socket path-length limits on macOS.
func createSockFile(t *testing.T, path string) {
t.Helper()
if err := os.WriteFile(path, nil, 0o600); err != nil {
t.Fatalf("failed to create test sock file at %s: %v", path, err)
}
}
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
tmp := t.TempDir()
sock := filepath.Join(tmp, "netbird.sock")
createSockFile(t, sock)
addr := "unix://" + sock
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
tmp := t.TempDir()
// Default socket does not exist
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
// Create a scan dir with one socket
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
instanceSock := filepath.Join(sd, "main.sock")
createSockFile(t, instanceSock)
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
expected := "unix://" + instanceSock
if got != expected {
t.Errorf("expected %s, got %s", expected, got)
}
}
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
createSockFile(t, filepath.Join(sd, "main.sock"))
createSockFile(t, filepath.Join(sd, "other.sock"))
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
addr := "tcp://127.0.0.1:41731"
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
origScanDir := scanDir
setScanDir(filepath.Join(tmp, "nonexistent"))
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}

View File

@@ -277,7 +277,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
}
}
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
return ruleIndex, nil
}

View File

@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
}
}
if serverDomains.Flow != "" {
domains = append(domains, serverDomains.Flow)
}
// Flow receiver domain is intentionally excluded from caching.
// Cloud providers may rotate the IP behind this domain; a stale cached record
// causes TLS certificate verification failures on reconnect.
for _, stun := range serverDomains.Stuns {
if stun != "" {

View File

@@ -391,7 +391,8 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
}
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
partialDomains := dnsconfig.ServerDomains{
Flow: "github.com",
}
@@ -400,10 +401,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains {
@@ -412,5 +413,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
assert.Contains(t, domainStrings, "example.org")
assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com")
assert.Contains(t, domainStrings, "github.com")
assert.NotContains(t, domainStrings, "github.com")
}

View File

@@ -351,9 +351,13 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return fmt.Errorf("upstream check call error")
}
err := backoff.Retry(operation, exponentialBackOff)
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
if err != nil {
log.Warn(err)
if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
return
}

View File

@@ -36,6 +36,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -53,13 +54,11 @@ import (
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -75,7 +74,6 @@ import (
const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled"
)
@@ -208,7 +206,6 @@ type Engine struct {
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// auto-update
@@ -224,6 +221,8 @@ type Engine struct {
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
exposeManager *expose.Manager
}
// Peer is an instance of the Connection Peer
@@ -266,7 +265,6 @@ func NewEngine(
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
}
@@ -419,6 +417,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface()
if err != nil {
@@ -801,7 +800,7 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled
// stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
@@ -1539,7 +1538,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1824,11 +1822,18 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
}
// GetFirewallManager returns the firewall manager
// GetFirewallManager returns the firewall manager.
func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall
}
// GetExposeManager returns the expose session manager.
func (e *Engine) GetExposeManager() *expose.Manager {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return e.exposeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -0,0 +1,95 @@
package expose
import (
"context"
"time"
mgm "github.com/netbirdio/netbird/shared/management/client"
log "github.com/sirupsen/logrus"
)
const renewTimeout = 10 * time.Second
// Response holds the response from exposing a service.
type Response struct {
ServiceName string
ServiceURL string
Domain string
}
type Request struct {
NamePrefix string
Domain string
Port uint16
Protocol int
Pin string
Password string
UserGroups []string
}
type ManagementClient interface {
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
RenewExpose(ctx context.Context, domain string) error
StopExpose(ctx context.Context, domain string) error
}
// Manager handles expose session lifecycle via the management client.
type Manager struct {
mgmClient ManagementClient
ctx context.Context
}
// NewManager creates a new expose Manager using the given management client.
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
return &Manager{mgmClient: mgmClient, ctx: ctx}
}
// Expose creates a new expose session via the management server.
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
log.Infof("exposing service on port %d", req.Port)
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
if err != nil {
return nil, err
}
log.Infof("expose session created for %s", resp.Domain)
return fromClientExposeResponse(resp), nil
}
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
defer m.stop(domain)
for {
select {
case <-ctx.Done():
log.Infof("context canceled, stopping keep alive for %s", domain)
return nil
case <-ticker.C:
if err := m.renew(ctx, domain); err != nil {
log.Errorf("renewing expose session for %s: %v", domain, err)
return err
}
}
}
}
// renew extends the TTL of an active expose session.
func (m *Manager) renew(ctx context.Context, domain string) error {
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
defer cancel()
return m.mgmClient.RenewExpose(renewCtx, domain)
}
// stop terminates an active expose session.
func (m *Manager) stop(domain string) {
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
defer cancel()
err := m.mgmClient.StopExpose(stopCtx, domain)
if err != nil {
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
}
}

View File

@@ -0,0 +1,95 @@
package expose
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
func TestManager_Expose_Success(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return &mgm.ExposeResponse{
ServiceName: "my-service",
ServiceURL: "https://my-service.example.com",
Domain: "my-service.example.com",
}, nil
},
}
m := NewManager(context.Background(), mock)
result, err := m.Expose(context.Background(), Request{Port: 8080})
require.NoError(t, err)
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
}
func TestManager_Expose_Error(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return nil, errors.New("permission denied")
},
}
m := NewManager(context.Background(), mock)
_, err := m.Expose(context.Background(), Request{Port: 8080})
require.Error(t, err)
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
}
func TestManager_Renew_Success(t *testing.T) {
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
return nil
},
}
m := NewManager(context.Background(), mock)
err := m.renew(context.Background(), "my-service.example.com")
require.NoError(t, err)
}
func TestManager_Renew_Timeout(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
return ctx.Err()
},
}
m := NewManager(ctx, mock)
err := m.renew(ctx, "my-service.example.com")
require.Error(t, err)
}
func TestNewRequest(t *testing.T) {
req := &daemonProto.ExposeServiceRequest{
Port: 8080,
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
Pin: "123456",
Password: "secret",
UserGroups: []string{"group1", "group2"},
Domain: "custom.example.com",
NamePrefix: "my-prefix",
}
exposeReq := NewRequest(req)
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
assert.Equal(t, "secret", exposeReq.Password, "password should match")
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
}

View File

@@ -0,0 +1,39 @@
package expose
import (
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
return &Request{
Port: uint16(req.Port),
Protocol: int(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
Domain: req.Domain,
NamePrefix: req.NamePrefix,
}
}
func toClientExposeRequest(req Request) mgm.ExposeRequest {
return mgm.ExposeRequest{
NamePrefix: req.NamePrefix,
Domain: req.Domain,
Port: req.Port,
Protocol: req.Protocol,
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
}
}
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
return &Response{
ServiceName: response.ServiceName,
Domain: response.Domain,
ServiceURL: response.ServiceURL,
}
}

View File

@@ -22,51 +22,56 @@ func prepareFd() (int, error) {
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
// Wait until fd is readable or context is cancelled, to avoid a busy-loop
// when the routing socket returns EAGAIN (e.g. immediately after wakeup).
if err := waitReadable(ctx, fd); err != nil {
return err
}
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
continue
}
if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) {
return fmt.Errorf("routing socket closed: %w", err)
}
return fmt.Errorf("read routing socket: %w", err)
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
}
}
}
}
@@ -90,3 +95,33 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
return systemops.MsgToRoute(msg)
}
// waitReadable blocks until fd has data to read, or ctx is cancelled.
func waitReadable(ctx context.Context, fd int) error {
var fdset unix.FdSet
if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) {
return fmt.Errorf("fd %d out of range for FdSet", fd)
}
for {
if err := ctx.Err(); err != nil {
return err
}
fdset = unix.FdSet{}
fdset.Set(fd)
// Use a 1-second timeout so we can re-check ctx periodically.
tv := unix.Timeval{Sec: 1}
n, err := unix.Select(fd+1, &fdset, nil, nil, &tv)
if err != nil {
if errors.Is(err, unix.EINTR) {
continue
}
return fmt.Errorf("select on routing socket: %w", err)
}
if n > 0 {
return nil
}
// timeout — loop back and re-check ctx
}
}

View File

@@ -3,7 +3,6 @@ package peer
import (
"context"
"fmt"
"math/rand"
"net"
"net/netip"
"runtime"
@@ -25,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
type ServiceDependencies struct {
@@ -34,7 +32,6 @@ type ServiceDependencies struct {
IFaceDiscover stdnet.ExternalIFaceDiscover
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
Semaphore *semaphoregroup.SemaphoreGroup
PeerConnDispatcher *dispatcher.ConnectionDispatcher
}
@@ -111,9 +108,8 @@ type Conn struct {
wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
wg sync.WaitGroup
guard *guard.Guard
wg sync.WaitGroup
// debug purpose
dumpState *stateDump
@@ -139,7 +135,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
@@ -154,15 +149,10 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
if err := conn.semaphore.Add(engineCtx); err != nil {
return err
}
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.opened {
conn.semaphore.Done()
return nil
}
@@ -173,7 +163,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
conn.semaphore.Done()
return err
}
conn.workerICE = workerICE
@@ -207,10 +196,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
conn.opened = true
@@ -670,19 +655,6 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
}
}
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
maxWait := 300
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
timeout := time.NewTimer(duration)
defer timeout.Stop()
select {
case <-ctx.Done():
case <-timeout.C:
}
}
func (conn *Conn) isRelayed() bool {
switch conn.currentConnPriority {
case conntype.Relay, conntype.ICETurn:

View File

@@ -15,7 +15,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
var testDispatcher = dispatcher.NewConnectionDispatcher()
@@ -53,7 +52,6 @@ func TestConn_GetKey(t *testing.T) {
sd := ServiceDependencies{
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
@@ -71,7 +69,6 @@ func TestConn_OnRemoteOffer(t *testing.T) {
sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
@@ -110,7 +107,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)

View File

@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
configDir := filepath.Join(DefaultConfigPathDir, username)
if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0600); err != nil {
if err := os.MkdirAll(configDir, 0700); err != nil {
return "", err
}
}
@@ -206,9 +206,15 @@ func getConfigDirForUser(username string) (string, error) {
return configDir, nil
}
func fileExists(path string) bool {
func fileExists(path string) (bool, error) {
_, err := os.Stat(path)
return !os.IsNotExist(err)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
// createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -635,7 +641,11 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
// UpdateConfig update existing configuration according to input configuration and return with the configuration
func UpdateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) {
configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
}
@@ -644,7 +654,11 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
// UpdateOrCreateConfig reads existing config or generates a new one
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) {
configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {
@@ -657,7 +671,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
err := util.EnforcePermission(input.ConfigPath)
err = util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
@@ -784,7 +798,12 @@ func ReadConfig(configPath string) (*Config, error) {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
if fileExists(configPath) {
configExists, err := fileExists(configPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if configExists {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
@@ -831,7 +850,11 @@ func DirectWriteOutConfig(path string, config *Config) error {
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) {
configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {

View File

@@ -256,7 +256,11 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
}
profPath := filepath.Join(configDir, profileName+".json")
if fileExists(profPath) {
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists
}
@@ -285,7 +289,11 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
if !fileExists(profPath) {
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if !profileExists {
return ErrProfileNotFound
}

View File

@@ -20,7 +20,11 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
}
stateFile := filepath.Join(configDir, profileName+".state.json")
if !fileExists(stateFile) {
stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
}
if !stateFileExists {
return nil, errors.New("profile state file does not exist")
}

View File

@@ -263,8 +263,14 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
case <-closer:
return
case routerStates := <-subscription.Events():
peerStateUpdate <- routerStates
log.Debugf("triggered route state update for Peer: %s", peerKey)
select {
case peerStateUpdate <- routerStates:
log.Debugf("triggered route state update for Peer: %s", peerKey)
case <-ctx.Done():
return
case <-closer:
return
}
}
}
}

View File

@@ -0,0 +1,80 @@
package handler
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
)
type Agent interface {
Up(ctx context.Context) error
Down(ctx context.Context) error
Status() (internal.StatusType, error)
}
type SleepHandler struct {
agent Agent
mu sync.Mutex
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
sleepTriggeredDown bool
}
func New(agent Agent) *SleepHandler {
return &SleepHandler{
agent: agent,
}
}
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.sleepTriggeredDown {
log.Info("skipping up because wasn't sleep down")
return nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown = false
log.Info("running up after wake up")
err := s.agent.Up(ctx)
if err != nil {
log.Errorf("running up failed: %v", err)
return err
}
log.Info("running up command executed successfully")
return nil
}
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
status, err := s.agent.Status()
if err != nil {
return err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
return nil
}
log.Info("running down after system started sleeping")
if err = s.agent.Down(ctx); err != nil {
log.Errorf("running down failed: %v", err)
return err
}
s.sleepTriggeredDown = true
log.Info("running down executed successfully")
return nil
}

View File

@@ -0,0 +1,153 @@
package handler
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockAgent struct {
upErr error
downErr error
statusErr error
status internal.StatusType
upCalls int
}
func (m *mockAgent) Up(_ context.Context) error {
m.upCalls++
return m.upErr
}
func (m *mockAgent) Down(_ context.Context) error {
return m.downErr
}
func (m *mockAgent) Status() (internal.StatusType, error) {
return m.status, m.statusErr
}
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
agent := &mockAgent{status: status}
return New(agent), agent
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
h, _ := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
// Even if Up fails, flag should be reset
_ = h.HandleWakeUp(context.Background())
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
}
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
agent.upErr = errors.New("up failed")
err := h.HandleWakeUp(context.Background())
assert.ErrorIs(t, err, agent.upErr)
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
}
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
_ = h.HandleWakeUp(context.Background())
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.False(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.True(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
agent := &mockAgent{statusErr: errors.New("status error")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.statusErr)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.downErr)
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v6.32.1
// protoc v6.33.3
// source: daemon.proto
package proto
@@ -88,6 +88,58 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{0}
}
type ExposeProtocol int32
const (
ExposeProtocol_EXPOSE_HTTP ExposeProtocol = 0
ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1
ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2
ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3
)
// Enum value maps for ExposeProtocol.
var (
ExposeProtocol_name = map[int32]string{
0: "EXPOSE_HTTP",
1: "EXPOSE_HTTPS",
2: "EXPOSE_TCP",
3: "EXPOSE_UDP",
}
ExposeProtocol_value = map[string]int32{
"EXPOSE_HTTP": 0,
"EXPOSE_HTTPS": 1,
"EXPOSE_TCP": 2,
"EXPOSE_UDP": 3,
}
)
func (x ExposeProtocol) Enum() *ExposeProtocol {
p := new(ExposeProtocol)
*p = x
return p
}
func (x ExposeProtocol) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[1].Descriptor()
}
func (ExposeProtocol) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[1]
}
func (x ExposeProtocol) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use ExposeProtocol.Descriptor instead.
func (ExposeProtocol) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{1}
}
// avoid collision with loglevel enum
type OSLifecycleRequest_CycleType int32
@@ -122,11 +174,11 @@ func (x OSLifecycleRequest_CycleType) String() string {
}
func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[1].Descriptor()
return file_daemon_proto_enumTypes[2].Descriptor()
}
func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[1]
return &file_daemon_proto_enumTypes[2]
}
func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber {
@@ -174,11 +226,11 @@ func (x SystemEvent_Severity) String() string {
}
func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[2].Descriptor()
return file_daemon_proto_enumTypes[3].Descriptor()
}
func (SystemEvent_Severity) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[2]
return &file_daemon_proto_enumTypes[3]
}
func (x SystemEvent_Severity) Number() protoreflect.EnumNumber {
@@ -229,11 +281,11 @@ func (x SystemEvent_Category) String() string {
}
func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[3].Descriptor()
return file_daemon_proto_enumTypes[4].Descriptor()
}
func (SystemEvent_Category) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[3]
return &file_daemon_proto_enumTypes[4]
}
func (x SystemEvent_Category) Number() protoreflect.EnumNumber {
@@ -5600,6 +5652,224 @@ func (x *InstallerResultResponse) GetErrorMsg() string {
return ""
}
type ExposeServiceRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"`
Protocol ExposeProtocol `protobuf:"varint,2,opt,name=protocol,proto3,enum=daemon.ExposeProtocol" json:"protocol,omitempty"`
Pin string `protobuf:"bytes,3,opt,name=pin,proto3" json:"pin,omitempty"`
Password string `protobuf:"bytes,4,opt,name=password,proto3" json:"password,omitempty"`
UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"`
Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"`
NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceRequest) Reset() {
*x = ExposeServiceRequest{}
mi := &file_daemon_proto_msgTypes[85]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceRequest) ProtoMessage() {}
func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[85]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead.
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{85}
}
func (x *ExposeServiceRequest) GetPort() uint32 {
if x != nil {
return x.Port
}
return 0
}
func (x *ExposeServiceRequest) GetProtocol() ExposeProtocol {
if x != nil {
return x.Protocol
}
return ExposeProtocol_EXPOSE_HTTP
}
func (x *ExposeServiceRequest) GetPin() string {
if x != nil {
return x.Pin
}
return ""
}
func (x *ExposeServiceRequest) GetPassword() string {
if x != nil {
return x.Password
}
return ""
}
func (x *ExposeServiceRequest) GetUserGroups() []string {
if x != nil {
return x.UserGroups
}
return nil
}
func (x *ExposeServiceRequest) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
func (x *ExposeServiceRequest) GetNamePrefix() string {
if x != nil {
return x.NamePrefix
}
return ""
}
type ExposeServiceEvent struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to Event:
//
// *ExposeServiceEvent_Ready
Event isExposeServiceEvent_Event `protobuf_oneof:"event"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceEvent) Reset() {
*x = ExposeServiceEvent{}
mi := &file_daemon_proto_msgTypes[86]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceEvent) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceEvent) ProtoMessage() {}
func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[86]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead.
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{86}
}
func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event {
if x != nil {
return x.Event
}
return nil
}
func (x *ExposeServiceEvent) GetReady() *ExposeServiceReady {
if x != nil {
if x, ok := x.Event.(*ExposeServiceEvent_Ready); ok {
return x.Ready
}
}
return nil
}
type isExposeServiceEvent_Event interface {
isExposeServiceEvent_Event()
}
type ExposeServiceEvent_Ready struct {
Ready *ExposeServiceReady `protobuf:"bytes,1,opt,name=ready,proto3,oneof"`
}
func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {}
type ExposeServiceReady struct {
state protoimpl.MessageState `protogen:"open.v1"`
ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"`
ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"`
Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceReady) Reset() {
*x = ExposeServiceReady{}
mi := &file_daemon_proto_msgTypes[87]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceReady) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceReady) ProtoMessage() {}
func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[87]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead.
func (*ExposeServiceReady) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{87}
}
func (x *ExposeServiceReady) GetServiceName() string {
if x != nil {
return x.ServiceName
}
return ""
}
func (x *ExposeServiceReady) GetServiceUrl() string {
if x != nil {
return x.ServiceUrl
}
return ""
}
func (x *ExposeServiceReady) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -5610,7 +5880,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[86]
mi := &file_daemon_proto_msgTypes[89]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5622,7 +5892,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[86]
mi := &file_daemon_proto_msgTypes[89]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -6149,7 +6419,25 @@ const file_daemon_proto_rawDesc = "" +
"\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\xe6\x01\n" +
"\x14ExposeServiceRequest\x12\x12\n" +
"\x04port\x18\x01 \x01(\rR\x04port\x122\n" +
"\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" +
"\x03pin\x18\x03 \x01(\tR\x03pin\x12\x1a\n" +
"\bpassword\x18\x04 \x01(\tR\bpassword\x12\x1f\n" +
"\vuser_groups\x18\x05 \x03(\tR\n" +
"userGroups\x12\x16\n" +
"\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" +
"\vname_prefix\x18\a \x01(\tR\n" +
"namePrefix\"Q\n" +
"\x12ExposeServiceEvent\x122\n" +
"\x05ready\x18\x01 \x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" +
"\x05event\"p\n" +
"\x12ExposeServiceReady\x12!\n" +
"\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" +
"\vservice_url\x18\x02 \x01(\tR\n" +
"serviceUrl\x12\x16\n" +
"\x06domain\x18\x03 \x01(\tR\x06domain*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -6158,7 +6446,14 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xdd\x14\n" +
"\x05TRACE\x10\a*S\n" +
"\x0eExposeProtocol\x12\x0f\n" +
"\vEXPOSE_HTTP\x10\x00\x12\x10\n" +
"\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" +
"\n" +
"EXPOSE_TCP\x10\x02\x12\x0e\n" +
"\n" +
"EXPOSE_UDP\x10\x032\xac\x15\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -6197,7 +6492,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00\x12M\n" +
"\rExposeService\x12\x1c.daemon.ExposeServiceRequest\x1a\x1a.daemon.ExposeServiceEvent\"\x000\x01B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -6211,214 +6507,222 @@ func file_daemon_proto_rawDescGZIP() []byte {
return file_daemon_proto_rawDescData
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88)
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
(SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity
(SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category
(*EmptyRequest)(nil), // 4: daemon.EmptyRequest
(*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest
(*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse
(*LoginRequest)(nil), // 7: daemon.LoginRequest
(*LoginResponse)(nil), // 8: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 11: daemon.UpRequest
(*UpResponse)(nil), // 12: daemon.UpResponse
(*StatusRequest)(nil), // 13: daemon.StatusRequest
(*StatusResponse)(nil), // 14: daemon.StatusResponse
(*DownRequest)(nil), // 15: daemon.DownRequest
(*DownResponse)(nil), // 16: daemon.DownResponse
(*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse
(*PeerState)(nil), // 19: daemon.PeerState
(*LocalPeerState)(nil), // 20: daemon.LocalPeerState
(*SignalState)(nil), // 21: daemon.SignalState
(*ManagementState)(nil), // 22: daemon.ManagementState
(*RelayState)(nil), // 23: daemon.RelayState
(*NSGroupState)(nil), // 24: daemon.NSGroupState
(*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo
(*SSHServerState)(nil), // 26: daemon.SSHServerState
(*FullStatus)(nil), // 27: daemon.FullStatus
(*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest
(*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse
(*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest
(*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse
(*IPList)(nil), // 32: daemon.IPList
(*Network)(nil), // 33: daemon.Network
(*PortInfo)(nil), // 34: daemon.PortInfo
(*ForwardingRule)(nil), // 35: daemon.ForwardingRule
(*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse
(*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse
(*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest
(*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse
(*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse
(*State)(nil), // 43: daemon.State
(*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest
(*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse
(*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest
(*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse
(*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest
(*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse
(*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest
(*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse
(*TCPFlags)(nil), // 52: daemon.TCPFlags
(*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest
(*TraceStage)(nil), // 54: daemon.TraceStage
(*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse
(*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest
(*SystemEvent)(nil), // 57: daemon.SystemEvent
(*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest
(*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse
(*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest
(*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse
(*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest
(*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse
(*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest
(*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse
(*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest
(*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse
(*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest
(*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse
(*Profile)(nil), // 70: daemon.Profile
(*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest
(*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse
(*LogoutRequest)(nil), // 73: daemon.LogoutRequest
(*LogoutResponse)(nil), // 74: daemon.LogoutResponse
(*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest
(*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse
(*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest
(*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse
(*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
nil, // 89: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range
nil, // 91: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp
(ExposeProtocol)(0), // 1: daemon.ExposeProtocol
(OSLifecycleRequest_CycleType)(0), // 2: daemon.OSLifecycleRequest.CycleType
(SystemEvent_Severity)(0), // 3: daemon.SystemEvent.Severity
(SystemEvent_Category)(0), // 4: daemon.SystemEvent.Category
(*EmptyRequest)(nil), // 5: daemon.EmptyRequest
(*OSLifecycleRequest)(nil), // 6: daemon.OSLifecycleRequest
(*OSLifecycleResponse)(nil), // 7: daemon.OSLifecycleResponse
(*LoginRequest)(nil), // 8: daemon.LoginRequest
(*LoginResponse)(nil), // 9: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 10: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 11: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 12: daemon.UpRequest
(*UpResponse)(nil), // 13: daemon.UpResponse
(*StatusRequest)(nil), // 14: daemon.StatusRequest
(*StatusResponse)(nil), // 15: daemon.StatusResponse
(*DownRequest)(nil), // 16: daemon.DownRequest
(*DownResponse)(nil), // 17: daemon.DownResponse
(*GetConfigRequest)(nil), // 18: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 19: daemon.GetConfigResponse
(*PeerState)(nil), // 20: daemon.PeerState
(*LocalPeerState)(nil), // 21: daemon.LocalPeerState
(*SignalState)(nil), // 22: daemon.SignalState
(*ManagementState)(nil), // 23: daemon.ManagementState
(*RelayState)(nil), // 24: daemon.RelayState
(*NSGroupState)(nil), // 25: daemon.NSGroupState
(*SSHSessionInfo)(nil), // 26: daemon.SSHSessionInfo
(*SSHServerState)(nil), // 27: daemon.SSHServerState
(*FullStatus)(nil), // 28: daemon.FullStatus
(*ListNetworksRequest)(nil), // 29: daemon.ListNetworksRequest
(*ListNetworksResponse)(nil), // 30: daemon.ListNetworksResponse
(*SelectNetworksRequest)(nil), // 31: daemon.SelectNetworksRequest
(*SelectNetworksResponse)(nil), // 32: daemon.SelectNetworksResponse
(*IPList)(nil), // 33: daemon.IPList
(*Network)(nil), // 34: daemon.Network
(*PortInfo)(nil), // 35: daemon.PortInfo
(*ForwardingRule)(nil), // 36: daemon.ForwardingRule
(*ForwardingRulesResponse)(nil), // 37: daemon.ForwardingRulesResponse
(*DebugBundleRequest)(nil), // 38: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 39: daemon.DebugBundleResponse
(*GetLogLevelRequest)(nil), // 40: daemon.GetLogLevelRequest
(*GetLogLevelResponse)(nil), // 41: daemon.GetLogLevelResponse
(*SetLogLevelRequest)(nil), // 42: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 43: daemon.SetLogLevelResponse
(*State)(nil), // 44: daemon.State
(*ListStatesRequest)(nil), // 45: daemon.ListStatesRequest
(*ListStatesResponse)(nil), // 46: daemon.ListStatesResponse
(*CleanStateRequest)(nil), // 47: daemon.CleanStateRequest
(*CleanStateResponse)(nil), // 48: daemon.CleanStateResponse
(*DeleteStateRequest)(nil), // 49: daemon.DeleteStateRequest
(*DeleteStateResponse)(nil), // 50: daemon.DeleteStateResponse
(*SetSyncResponsePersistenceRequest)(nil), // 51: daemon.SetSyncResponsePersistenceRequest
(*SetSyncResponsePersistenceResponse)(nil), // 52: daemon.SetSyncResponsePersistenceResponse
(*TCPFlags)(nil), // 53: daemon.TCPFlags
(*TracePacketRequest)(nil), // 54: daemon.TracePacketRequest
(*TraceStage)(nil), // 55: daemon.TraceStage
(*TracePacketResponse)(nil), // 56: daemon.TracePacketResponse
(*SubscribeRequest)(nil), // 57: daemon.SubscribeRequest
(*SystemEvent)(nil), // 58: daemon.SystemEvent
(*GetEventsRequest)(nil), // 59: daemon.GetEventsRequest
(*GetEventsResponse)(nil), // 60: daemon.GetEventsResponse
(*SwitchProfileRequest)(nil), // 61: daemon.SwitchProfileRequest
(*SwitchProfileResponse)(nil), // 62: daemon.SwitchProfileResponse
(*SetConfigRequest)(nil), // 63: daemon.SetConfigRequest
(*SetConfigResponse)(nil), // 64: daemon.SetConfigResponse
(*AddProfileRequest)(nil), // 65: daemon.AddProfileRequest
(*AddProfileResponse)(nil), // 66: daemon.AddProfileResponse
(*RemoveProfileRequest)(nil), // 67: daemon.RemoveProfileRequest
(*RemoveProfileResponse)(nil), // 68: daemon.RemoveProfileResponse
(*ListProfilesRequest)(nil), // 69: daemon.ListProfilesRequest
(*ListProfilesResponse)(nil), // 70: daemon.ListProfilesResponse
(*Profile)(nil), // 71: daemon.Profile
(*GetActiveProfileRequest)(nil), // 72: daemon.GetActiveProfileRequest
(*GetActiveProfileResponse)(nil), // 73: daemon.GetActiveProfileResponse
(*LogoutRequest)(nil), // 74: daemon.LogoutRequest
(*LogoutResponse)(nil), // 75: daemon.LogoutResponse
(*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest
(*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse
(*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest
(*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse
(*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest
(*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse
(*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest
(*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse
(*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest
(*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse
(*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse
(*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest
(*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent
(*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady
nil, // 93: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 94: daemon.PortInfo.Range
nil, // 95: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 96: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 97: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
96, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
97, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
97, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
96, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
93, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
94, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest
13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest
17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
69, // [69:104] is the sub-list for method output_type
34, // [34:69] is the sub-list for method input_type
34, // [34:34] is the sub-list for extension type_name
34, // [34:34] is the sub-list for extension extendee
0, // [0:34] is the sub-list for field type_name
44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
97, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
95, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
96, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
92, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest
14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest
18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
78, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
80, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
82, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
84, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
86, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
6, // 69: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
88, // 70: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
90, // 71: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
9, // 72: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
11, // 73: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
13, // 74: daemon.DaemonService.Up:output_type -> daemon.UpResponse
15, // 75: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
17, // 76: daemon.DaemonService.Down:output_type -> daemon.DownResponse
19, // 77: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
30, // 78: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
32, // 79: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
32, // 80: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
37, // 81: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
39, // 82: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
41, // 83: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
43, // 84: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
46, // 85: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
48, // 86: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
50, // 87: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
52, // 88: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
56, // 89: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
58, // 90: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
60, // 91: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
62, // 92: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
64, // 93: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
66, // 94: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
68, // 95: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
70, // 96: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
73, // 97: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
75, // 98: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
77, // 99: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
79, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
81, // 101: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
83, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
85, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
87, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
7, // 105: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
89, // 106: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
91, // 107: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
72, // [72:108] is the sub-list for method output_type
36, // [36:72] is the sub-list for method input_type
36, // [36:36] is the sub-list for extension type_name
36, // [36:36] is the sub-list for extension extendee
0, // [0:36] is the sub-list for field type_name
}
func init() { file_daemon_proto_init() }
@@ -6439,13 +6743,16 @@ func file_daemon_proto_init() {
file_daemon_proto_msgTypes[58].OneofWrappers = []any{}
file_daemon_proto_msgTypes[69].OneofWrappers = []any{}
file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
file_daemon_proto_msgTypes[86].OneofWrappers = []any{
(*ExposeServiceEvent_Ready)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4,
NumMessages: 88,
NumEnums: 5,
NumMessages: 91,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -103,6 +103,9 @@ service DaemonService {
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
// ExposeService exposes a local port via the NetBird reverse proxy
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
}
@@ -801,3 +804,32 @@ message InstallerResultResponse {
bool success = 1;
string errorMsg = 2;
}
enum ExposeProtocol {
EXPOSE_HTTP = 0;
EXPOSE_HTTPS = 1;
EXPOSE_TCP = 2;
EXPOSE_UDP = 3;
}
message ExposeServiceRequest {
uint32 port = 1;
ExposeProtocol protocol = 2;
string pin = 3;
string password = 4;
repeated string user_groups = 5;
string domain = 6;
string name_prefix = 7;
}
message ExposeServiceEvent {
oneof event {
ExposeServiceReady ready = 1;
}
}
message ExposeServiceReady {
string service_name = 1;
string service_url = 2;
string domain = 3;
}

View File

@@ -76,6 +76,8 @@ type DaemonServiceClient interface {
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error)
}
type daemonServiceClient struct {
@@ -424,6 +426,38 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal
return out, nil
}
func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) {
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...)
if err != nil {
return nil, err
}
x := &daemonServiceExposeServiceClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type DaemonService_ExposeServiceClient interface {
Recv() (*ExposeServiceEvent, error)
grpc.ClientStream
}
type daemonServiceExposeServiceClient struct {
grpc.ClientStream
}
func (x *daemonServiceExposeServiceClient) Recv() (*ExposeServiceEvent, error) {
m := new(ExposeServiceEvent)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -486,6 +520,8 @@ type DaemonServiceServer interface {
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -598,6 +634,9 @@ func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLi
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
}
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error {
return status.Errorf(codes.Unimplemented, "method ExposeService not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1244,6 +1283,27 @@ func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Cont
return interceptor(ctx, in, info, handler)
}
func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(ExposeServiceRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(DaemonServiceServer).ExposeService(m, &daemonServiceExposeServiceServer{stream})
}
type DaemonService_ExposeServiceServer interface {
Send(*ExposeServiceEvent) error
grpc.ServerStream
}
type daemonServiceExposeServiceServer struct {
grpc.ServerStream
}
func (x *daemonServiceExposeServiceServer) Send(m *ExposeServiceEvent) error {
return x.ServerStream.SendMsg(m)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1394,6 +1454,11 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
Handler: _DaemonService_SubscribeEvents_Handler,
ServerStreams: true,
},
{
StreamName: "ExposeService",
Handler: _DaemonService_ExposeService_Handler,
ServerStreams: true,
},
},
Metadata: "daemon.proto",
}

View File

@@ -1,77 +0,0 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
return s.handleWakeUp(callerCtx)
case proto.OSLifecycleRequest_SLEEP:
return s.handleSleep(callerCtx)
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
if !s.sleepTriggeredDown.Load() {
log.Info("skipping up because wasn't sleep down")
return &proto.OSLifecycleResponse{}, nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown.Store(false)
log.Info("running up after wake up")
_, err := s.Up(callerCtx, &proto.UpRequest{})
if err != nil {
log.Errorf("running up failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
log.Info("running up command executed successfully")
return &proto.OSLifecycleResponse{}, nil
}
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
s.mutex.Lock()
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, nil
}
s.mutex.Unlock()
log.Info("running down after system started sleeping")
_, err = s.Down(callerCtx, &proto.DownRequest{})
if err != nil {
log.Errorf("running down failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
s.sleepTriggeredDown.Store(true)
log.Info("running down executed successfully")
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -1,219 +0,0 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
func newTestServer() *Server {
ctx := internal.CtxInitState(context.Background())
return &Server{
rootCtx: ctx,
statusRecorder: peer.NewRecorder(""),
}
}
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
s := newTestServer()
// sleepTriggeredDown is false by default
assert.False(t, s.sleepTriggeredDown.Load())
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusNeedsLogin)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnecting)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnected)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
}
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
s := newTestServer()
// Manually set the flag to simulate prior sleep down
s.sleepTriggeredDown.Store(true)
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
}
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
s := newTestServer()
// First wakeup without prior sleep - should be no-op
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
// Simulate prior sleep
s.sleepTriggeredDown.Store(true)
// First wakeup after sleep - should reset flag
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load())
// Second wakeup - should be no-op
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
s := newTestServer()
resp, err := s.handleWakeUp(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
s := newTestServer()
s.sleepTriggeredDown.Store(true)
// Even if Up fails, flag should be reset
_, _ = s.handleWakeUp(context.Background())
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
resp, err := s.handleSleep(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.handleSleep(ctx)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.True(t, s.sleepTriggeredDown.Load())
})
}
}

View File

@@ -21,7 +21,9 @@ import (
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -85,8 +87,7 @@ type Server struct {
profilesDisabled bool
updateSettingsDisabled bool
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
sleepTriggeredDown atomic.Bool
sleepHandler *sleephandler.SleepHandler
jwtCache *jwtCache
}
@@ -100,7 +101,7 @@ type oauthAuthFlow struct {
// New server instance constructor.
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
return &Server{
s := &Server{
rootCtx: ctx,
logFile: logFile,
persistSyncResponse: true,
@@ -110,6 +111,10 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(),
}
agent := &serverAgent{s}
s.sleepHandler = sleephandler.New(agent)
return s
}
func (s *Server) Start() error {
@@ -636,8 +641,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return s.waitForUp(callerCtx)
}
defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
log.Warnf(errRestoreResidualState, err)
}
@@ -649,10 +652,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
// not in the progress or already successfully established connection.
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return nil, err
}
if status != internal.StatusIdle {
s.mutex.Unlock()
return nil, fmt.Errorf("up already in progress: current status %s", status)
}
@@ -669,17 +674,20 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.actCancel = cancel
if s.config == nil {
s.mutex.Unlock()
return nil, fmt.Errorf("config is not defined, please call login command first")
}
activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err)
}
@@ -687,6 +695,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
activeProf, err = s.profileManager.GetActiveProfileState()
if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
@@ -695,6 +704,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
config, _, err := s.getConfig(activeProf)
if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
}
@@ -713,6 +723,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
}
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
s.mutex.Unlock()
return s.waitForUp(callerCtx)
}
@@ -838,14 +849,26 @@ func (s *Server) cleanupConnection() error {
if s.actCancel == nil {
return ErrServiceNotUp
}
// Capture the engine reference before cancelling the context.
// After actCancel(), the connectWithRetryRuns goroutine wakes up
// and sets connectClient.engine = nil, causing connectClient.Stop()
// to skip the engine shutdown entirely.
var engine *internal.Engine
if s.connectClient != nil {
engine = s.connectClient.Engine()
}
s.actCancel()
if s.connectClient == nil {
return nil
}
if err := s.connectClient.Stop(); err != nil {
return err
if engine != nil {
if err := engine.Stop(); err != nil {
return err
}
}
s.connectClient = nil
@@ -1312,6 +1335,60 @@ func (s *Server) WaitJWTToken(
}, nil
}
// ExposeService exposes a local port via the NetBird reverse proxy.
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
s.mutex.Lock()
if !s.clientRunning {
s.mutex.Unlock()
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
}
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
}
mgr := engine.GetExposeManager()
if mgr == nil {
return gstatus.Errorf(codes.Internal, "expose manager not available")
}
ctx := srv.Context()
exposeCtx, exposeCancel := context.WithTimeout(ctx, 30*time.Second)
defer exposeCancel()
mgmReq := expose.NewRequest(req)
result, err := mgr.Expose(exposeCtx, *mgmReq)
if err != nil {
return err
}
if err := srv.Send(&proto.ExposeServiceEvent{
Event: &proto.ExposeServiceEvent_Ready{
Ready: &proto.ExposeServiceReady{
ServiceName: result.ServiceName,
ServiceUrl: result.ServiceURL,
Domain: result.Domain,
},
},
}); err != nil {
return err
}
err = mgr.KeepAlive(ctx, result.Domain)
if err != nil {
return err
}
return nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
@@ -1548,9 +1625,14 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan, s.logFile); err != nil {
client := internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
client.SetSyncResponsePersistence(s.persistSyncResponse)
s.mutex.Lock()
s.connectClient = client
s.mutex.Unlock()
if err := client.Run(runningChan, s.logFile); err != nil {
return err
}
return nil

View File

@@ -0,0 +1,187 @@
package server
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
func newTestServer() *Server {
return &Server{
rootCtx: context.Background(),
statusRecorder: peer.NewRecorder(""),
}
}
func newDummyConnectClient(ctx context.Context) *internal.ConnectClient {
return internal.NewConnectClient(ctx, nil, nil, false)
}
// TestConnectSetsClientWithMutex validates that connect() sets s.connectClient
// under mutex protection so concurrent readers see a consistent value.
func TestConnectSetsClientWithMutex(t *testing.T) {
s := newTestServer()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Manually simulate what connect() does (without calling Run which panics without full setup)
client := newDummyConnectClient(ctx)
s.mutex.Lock()
s.connectClient = client
s.mutex.Unlock()
// Verify the assignment is visible under mutex
s.mutex.Lock()
assert.Equal(t, client, s.connectClient, "connectClient should be set")
s.mutex.Unlock()
}
// TestConcurrentConnectClientAccess validates that concurrent reads of
// s.connectClient under mutex don't race with a write.
func TestConcurrentConnectClientAccess(t *testing.T) {
s := newTestServer()
ctx := context.Background()
client := newDummyConnectClient(ctx)
var wg sync.WaitGroup
nilCount := 0
setCount := 0
var mu sync.Mutex
// Start readers
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
s.mutex.Lock()
c := s.connectClient
s.mutex.Unlock()
mu.Lock()
defer mu.Unlock()
if c == nil {
nilCount++
} else {
setCount++
}
}()
}
// Simulate connect() writing under mutex
time.Sleep(5 * time.Millisecond)
s.mutex.Lock()
s.connectClient = client
s.mutex.Unlock()
wg.Wait()
assert.Equal(t, 50, nilCount+setCount, "all goroutines should complete without panic")
}
// TestCleanupConnection_ClearsConnectClient validates that cleanupConnection
// properly nils out connectClient.
func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
s := newTestServer()
_, cancel := context.WithCancel(context.Background())
s.actCancel = cancel
s.connectClient = newDummyConnectClient(context.Background())
s.clientRunning = true
err := s.cleanupConnection()
require.NoError(t, err)
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
}
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
// when connectClient is nil.
func TestCleanState_NilConnectClient(t *testing.T) {
s := newTestServer()
s.connectClient = nil
s.profileManager = nil // will cause error if it tries to proceed past the nil check
// Should not panic — the nil check should prevent calling Status() on nil
assert.NotPanics(t, func() {
_, _ = s.CleanState(context.Background(), &proto.CleanStateRequest{All: true})
})
}
// TestDeleteState_NilConnectClient validates that DeleteState doesn't panic
// when connectClient is nil.
func TestDeleteState_NilConnectClient(t *testing.T) {
s := newTestServer()
s.connectClient = nil
s.profileManager = nil
assert.NotPanics(t, func() {
_, _ = s.DeleteState(context.Background(), &proto.DeleteStateRequest{All: true})
})
}
// TestDownThenUp_StaleRunningChan documents the known state issue where
// clientRunningChan from a previous connection is already closed, causing
// waitForUp() to return immediately on reconnect.
func TestDownThenUp_StaleRunningChan(t *testing.T) {
s := newTestServer()
// Simulate state after a successful connection
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
close(s.clientRunningChan) // closed when engine started
s.clientGiveUpChan = make(chan struct{})
s.connectClient = newDummyConnectClient(context.Background())
_, cancel := context.WithCancel(context.Background())
s.actCancel = cancel
// Simulate Down(): cleanupConnection sets connectClient = nil
s.mutex.Lock()
err := s.cleanupConnection()
s.mutex.Unlock()
require.NoError(t, err)
// After cleanup: connectClient is nil, clientRunning still true
// (goroutine hasn't exited yet)
s.mutex.Lock()
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
assert.True(t, s.clientRunning, "clientRunning still true until goroutine exits")
s.mutex.Unlock()
// waitForUp() returns immediately due to stale closed clientRunningChan
ctx, ctxCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer ctxCancel()
waitDone := make(chan error, 1)
go func() {
_, err := s.waitForUp(ctx)
waitDone <- err
}()
select {
case err := <-waitDone:
assert.NoError(t, err, "waitForUp returns success on stale channel")
// But connectClient is still nil — this is the stale state issue
s.mutex.Lock()
assert.Nil(t, s.connectClient, "connectClient is nil despite waitForUp success")
s.mutex.Unlock()
case <-time.After(1 * time.Second):
t.Fatal("waitForUp should have returned immediately due to stale closed channel")
}
}
// TestConnectClient_EngineNilOnFreshClient validates that a newly created
// ConnectClient has nil Engine (before Run is called).
func TestConnectClient_EngineNilOnFreshClient(t *testing.T) {
client := newDummyConnectClient(context.Background())
assert.Nil(t, client.Engine(), "engine should be nil on fresh ConnectClient")
}

46
client/server/sleep.go Normal file
View File

@@ -0,0 +1,46 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
type serverAgent struct {
s *Server
}
func (a *serverAgent) Up(ctx context.Context) error {
_, err := a.s.Up(ctx, &proto.UpRequest{})
return err
}
func (a *serverAgent) Down(ctx context.Context) error {
_, err := a.s.Down(ctx, &proto.DownRequest{})
return err
}
func (a *serverAgent) Status() (internal.StatusType, error) {
return internal.CtxGetState(a.s.rootCtx).Status()
}
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
case proto.OSLifecycleRequest_SLEEP:
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -39,7 +39,7 @@ func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*pro
// CleanState handles cleaning of states (performing cleanup operations)
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}
@@ -82,7 +82,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (
// DeleteState handles deletion of states without cleanup
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
@@ -268,7 +269,7 @@ func getDefaultDaemonAddr() string {
if runtime.GOOS == "windows" {
return DefaultDaemonAddrWindows
}
return DefaultDaemonAddr
return daemonaddr.ResolveUnixDaemonAddr(DefaultDaemonAddr)
}
// DialOptions contains options for SSH connections

View File

@@ -46,8 +46,10 @@ const (
cmdSFTP = "<sftp>"
cmdNonInteractive = "<idle>"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server.
// Set to 10 minutes to accommodate identity providers like Azure Entra ID
// that backdate the iat claim by up to 5 minutes.
DefaultJWTMaxTokenAge = 10 * 60
)
var (

View File

@@ -323,7 +323,7 @@ type serviceClient struct {
exitNodeMu sync.Mutex
mExitNodeItems []menuHandler
exitNodeStates []exitNodeState
exitNodeRetryCancel context.CancelFunc
mExitNodeDeselectAll *systray.MenuItem
logFile string
wLoginURL fyne.Window
@@ -924,7 +924,7 @@ func (s *serviceClient) updateStatus() error {
s.mDown.Enable()
s.mNetworks.Enable()
s.mExitNode.Enable()
go s.updateExitNodes()
s.startExitNodeRefresh()
systrayIconState = true
case status.Status == string(internal.StatusConnecting):
s.setConnectingStatus()
@@ -985,6 +985,7 @@ func (s *serviceClient) setDisconnectedStatus() {
s.mUp.Enable()
s.mNetworks.Disable()
s.mExitNode.Disable()
s.cancelExitNodeRetry()
go s.updateExitNodes()
}

View File

@@ -100,8 +100,7 @@ func (h *eventHandler) handleConnectClick() {
func (h *eventHandler) handleDisconnectClick() {
h.client.mDown.Disable()
h.client.exitNodeStates = []exitNodeState{}
h.client.cancelExitNodeRetry()
if h.client.connectCancel != nil {
log.Debugf("cancelling ongoing connect operation")

View File

@@ -6,7 +6,6 @@ import (
"context"
"fmt"
"runtime"
"slices"
"sort"
"strings"
"time"
@@ -34,11 +33,6 @@ const (
type filter string
type exitNodeState struct {
id string
selected bool
}
func (s *serviceClient) showNetworksUI() {
s.wNetworks = s.app.NewWindow("Networks")
s.wNetworks.SetOnClosed(s.cancel)
@@ -335,16 +329,75 @@ func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs,
s.updateNetworks(grid, f)
}
func (s *serviceClient) updateExitNodes() {
// startExitNodeRefresh initiates exit node menu refresh after connecting.
// On Windows, TrayOpenedCh is not supported by the systray library, so we use
// a background poller to keep exit nodes in sync while connected.
// On macOS/Linux, TrayOpenedCh handles refreshes on each tray open.
func (s *serviceClient) startExitNodeRefresh() {
s.cancelExitNodeRetry()
if runtime.GOOS == "windows" {
ctx, cancel := context.WithCancel(s.ctx)
s.exitNodeMu.Lock()
s.exitNodeRetryCancel = cancel
s.exitNodeMu.Unlock()
go s.pollExitNodes(ctx)
} else {
go s.updateExitNodes()
}
}
func (s *serviceClient) cancelExitNodeRetry() {
s.exitNodeMu.Lock()
if s.exitNodeRetryCancel != nil {
s.exitNodeRetryCancel()
s.exitNodeRetryCancel = nil
}
s.exitNodeMu.Unlock()
}
// pollExitNodes periodically refreshes exit nodes while connected.
// Uses a short initial interval to catch routes from the management sync,
// then switches to a longer interval for ongoing updates.
func (s *serviceClient) pollExitNodes(ctx context.Context) {
// Initial fast polling to catch routes as they appear after connect.
for i := 0; i < 5; i++ {
if s.updateExitNodes() {
break
}
select {
case <-ctx.Done():
return
case <-time.After(2 * time.Second):
}
}
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.updateExitNodes()
}
}
}
// updateExitNodes fetches exit nodes from the daemon and recreates the menu.
// Returns true if exit nodes were found.
func (s *serviceClient) updateExitNodes() bool {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
return false
}
exitNodes, err := s.getExitNodes(conn)
if err != nil {
log.Errorf("get exit nodes: %v", err)
return
return false
}
s.exitNodeMu.Lock()
@@ -354,28 +407,14 @@ func (s *serviceClient) updateExitNodes() {
if len(s.mExitNodeItems) > 0 {
s.mExitNode.Enable()
} else {
s.mExitNode.Disable()
return true
}
s.mExitNode.Disable()
return false
}
func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
var exitNodeIDs []exitNodeState
for _, node := range exitNodes {
exitNodeIDs = append(exitNodeIDs, exitNodeState{
id: node.ID,
selected: node.Selected,
})
}
sort.Slice(exitNodeIDs, func(i, j int) bool {
return exitNodeIDs[i].id < exitNodeIDs[j].id
})
if slices.Equal(s.exitNodeStates, exitNodeIDs) {
log.Debug("Exit node menu already up to date")
return
}
for _, node := range s.mExitNodeItems {
node.cancel()
node.Hide()
@@ -413,8 +452,6 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
go s.handleChecked(ctx, node.ID, menuItem)
}
s.exitNodeStates = exitNodeIDs
if showDeselectAll {
s.mExitNode.AddSeparator()
deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All")

View File

@@ -7,6 +7,7 @@ import (
"net/netip"
"os"
"path"
"path/filepath"
"strings"
"time"
@@ -70,6 +71,8 @@ type ServerConfig struct {
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"`
ActivityStore StoreConfig `yaml:"activityStore"`
AuthStore StoreConfig `yaml:"authStore"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
@@ -170,7 +173,8 @@ type RelaysConfig struct {
type StoreConfig struct {
Engine string `yaml:"engine"`
EncryptionKey string `yaml:"encryptionKey"`
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
File string `yaml:"file"` // SQLite database file path (optional, defaults to dataDir)
}
// ReverseProxyConfig contains reverse proxy settings
@@ -532,6 +536,74 @@ func stripSignalProtocol(uri string) string {
return uri
}
func buildRelayConfig(relays RelaysConfig) (*nbconfig.Relay, error) {
var ttl time.Duration
if relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", relays.CredentialsTTL, err)
}
}
return &nbconfig.Relay{
Addresses: relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: relays.Secret,
}, nil
}
// buildEmbeddedIdPConfig builds the embedded IdP configuration.
// authStore overrides auth.storage when set.
func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.EmbeddedIdPConfig, error) {
authStorageType := mgmt.Auth.Storage.Type
authStorageDSN := c.Server.AuthStore.DSN
if c.Server.AuthStore.Engine != "" {
authStorageType = c.Server.AuthStore.Engine
}
if authStorageType == "" {
authStorageType = "sqlite3"
}
authStorageFile := ""
if authStorageType == "postgres" {
if authStorageDSN == "" {
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
}
} else {
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
if c.Server.AuthStore.File != "" {
authStorageFile = c.Server.AuthStore.File
if !filepath.IsAbs(authStorageFile) {
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
}
}
}
cfg := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: authStorageType,
Config: idp.EmbeddedStorageTypeConfig{
File: authStorageFile,
DSN: authStorageDSN,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
cfg.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password,
}
}
return cfg, nil
}
// ToManagementConfig converts CombinedConfig to management server config
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
mgmt := c.Management
@@ -550,19 +622,11 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
// Build relay config
var relayConfig *nbconfig.Relay
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
var ttl time.Duration
if mgmt.Relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
}
}
relayConfig = &nbconfig.Relay{
Addresses: mgmt.Relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: mgmt.Relays.Secret,
relay, err := buildRelayConfig(mgmt.Relays)
if err != nil {
return nil, err
}
relayConfig = relay
}
// Build signal config
@@ -598,31 +662,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
httpConfig := &nbconfig.HttpServerConfig{}
// Build embedded IDP config (always enabled in combined server)
storageFile := mgmt.Auth.Storage.File
if storageFile == "" {
storageFile = path.Join(mgmt.DataDir, "idp.db")
}
embeddedIdP := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: mgmt.Auth.Storage.Type,
Config: idp.EmbeddedStorageTypeConfig{
File: storageFile,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
embeddedIdP.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
}
embeddedIdP, err := c.buildEmbeddedIdPConfig(mgmt)
if err != nil {
return nil, err
}
// Set HTTP config fields for embedded IDP

View File

@@ -140,6 +140,23 @@ func initializeConfig() error {
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := config.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
if engine := config.Server.ActivityStore.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
if file := config.Server.ActivityStore.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
}
log.Infof("Starting combined NetBird server")
logConfig(config)
@@ -476,9 +493,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil {
@@ -490,8 +504,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
mgmtSrv := mgmtServer.NewServer(
&mgmtServer.Config{
NbConfig: mgmtConfig,
DNSDomain: dnsDomain,
MgmtSingleAccModeDomain: singleAccModeDomain,
DNSDomain: "",
MgmtSingleAccModeDomain: "",
AutoResolveDomains: true,
MgmtPort: mgmtPort,
MgmtMetricsPort: cfg.Server.MetricsPort,
DisableMetrics: mgmt.DisableAnonymousMetrics,
@@ -668,8 +683,11 @@ func logEnvVars() {
if strings.HasPrefix(env, "NB_") {
key, _, _ := strings.Cut(env, "=")
value := os.Getenv(key)
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
keyLower := strings.ToLower(key)
if strings.Contains(keyLower, "secret") || strings.Contains(keyLower, "key") || strings.Contains(keyLower, "password") {
value = maskSecret(value)
} else if strings.Contains(keyLower, "dsn") {
value = maskDSNPassword(value)
}
log.Infof(" %s=%s", key, value)
found = true

View File

@@ -42,6 +42,9 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine)

View File

@@ -103,6 +103,19 @@ server:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/store.db)
# Activity events store configuration (optional, defaults to sqlite in dataDir)
# activityStore:
# engine: "sqlite" # sqlite or postgres
# dsn: "" # Connection string for postgres
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/events.db)
# Auth (embedded IdP) store configuration (optional, defaults to sqlite3 in dataDir/idp.db)
# authStore:
# engine: "sqlite3" # sqlite3 or postgres
# dsn: "" # Connection string for postgres (e.g., "host=localhost port=5432 user=postgres password=postgres dbname=netbird_idp sslmode=disable")
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/idp.db)
# Reverse proxy settings (optional)
# reverseProxy:

View File

@@ -5,7 +5,10 @@ import (
"encoding/json"
"fmt"
"log/slog"
"net/url"
"os"
"strconv"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
@@ -195,11 +198,175 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
}
return (&sql.SQLite3{File: file}).Open(logger)
case "postgres":
dsn, _ := s.Config["dsn"].(string)
if dsn == "" {
return nil, fmt.Errorf("postgres storage requires 'dsn' config")
}
pg, err := parsePostgresDSN(dsn)
if err != nil {
return nil, fmt.Errorf("invalid postgres DSN: %w", err)
}
return pg.Open(logger)
default:
return nil, fmt.Errorf("unsupported storage type: %s", s.Type)
}
}
// parsePostgresDSN parses a DSN into a sql.Postgres config.
// It accepts both URI format (postgres://user:pass@host:port/dbname?sslmode=disable)
// and libpq key=value format (host=localhost port=5432 dbname=mydb), including quoted values.
func parsePostgresDSN(dsn string) (*sql.Postgres, error) {
var params map[string]string
var err error
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
params, err = parsePostgresURI(dsn)
} else {
params, err = parsePostgresKeyValue(dsn)
}
if err != nil {
return nil, err
}
host := params["host"]
if host == "" {
host = "localhost"
}
var port uint16 = 5432
if p, ok := params["port"]; ok && p != "" {
v, err := strconv.ParseUint(p, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid port %q: %w", p, err)
}
if v == 0 {
return nil, fmt.Errorf("invalid port %q: must be non-zero", p)
}
port = uint16(v)
}
dbname := params["dbname"]
if dbname == "" {
return nil, fmt.Errorf("dbname is required in DSN")
}
pg := &sql.Postgres{
NetworkDB: sql.NetworkDB{
Host: host,
Port: port,
Database: dbname,
User: params["user"],
Password: params["password"],
},
}
if sslMode := params["sslmode"]; sslMode != "" {
switch sslMode {
case "disable", "allow", "prefer", "require", "verify-ca", "verify-full":
pg.SSL.Mode = sslMode
default:
return nil, fmt.Errorf("unsupported sslmode %q: valid values are disable, allow, prefer, require, verify-ca, verify-full", sslMode)
}
}
return pg, nil
}
// parsePostgresURI parses a postgres:// or postgresql:// URI into parameter key-value pairs.
func parsePostgresURI(dsn string) (map[string]string, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid postgres URI: %w", err)
}
params := make(map[string]string)
if u.User != nil {
params["user"] = u.User.Username()
if p, ok := u.User.Password(); ok {
params["password"] = p
}
}
if u.Hostname() != "" {
params["host"] = u.Hostname()
}
if u.Port() != "" {
params["port"] = u.Port()
}
dbname := strings.TrimPrefix(u.Path, "/")
if dbname != "" {
params["dbname"] = dbname
}
for k, v := range u.Query() {
if len(v) > 0 {
params[k] = v[0]
}
}
return params, nil
}
// parsePostgresKeyValue parses a libpq key=value DSN string, handling single-quoted values
// (e.g., password='my pass' host=localhost).
func parsePostgresKeyValue(dsn string) (map[string]string, error) {
params := make(map[string]string)
s := strings.TrimSpace(dsn)
for s != "" {
eqIdx := strings.IndexByte(s, '=')
if eqIdx < 0 {
break
}
key := strings.TrimSpace(s[:eqIdx])
value, rest, err := parseDSNValue(s[eqIdx+1:])
if err != nil {
return nil, fmt.Errorf("%w for key %q", err, key)
}
params[key] = value
s = strings.TrimSpace(rest)
}
return params, nil
}
// parseDSNValue parses the next value from a libpq key=value string positioned after the '='.
// It returns the parsed value and the remaining unparsed string.
func parseDSNValue(s string) (value, rest string, err error) {
if len(s) > 0 && s[0] == '\'' {
return parseQuotedDSNValue(s[1:])
}
// Unquoted value: read until whitespace.
idx := strings.IndexAny(s, " \t\n")
if idx < 0 {
return s, "", nil
}
return s[:idx], s[idx:], nil
}
// parseQuotedDSNValue parses a single-quoted value starting after the opening quote.
// Libpq uses ” to represent a literal single quote inside quoted values.
func parseQuotedDSNValue(s string) (value, rest string, err error) {
var buf strings.Builder
for len(s) > 0 {
if s[0] == '\'' {
if len(s) > 1 && s[1] == '\'' {
buf.WriteByte('\'')
s = s[2:]
continue
}
return buf.String(), s[1:], nil
}
buf.WriteByte(s[0])
s = s[1:]
}
return "", "", fmt.Errorf("unterminated quoted value")
}
// Validate validates the configuration
func (c *YAMLConfig) Validate() error {
if c.Issuer == "" {

View File

@@ -63,6 +63,8 @@ type Controller struct {
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
}
type bufferUpdate struct {
@@ -85,6 +87,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
newNetworkMapBuilder = false
}
compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err)
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
@@ -108,6 +116,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
}
}
@@ -230,9 +240,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
@@ -355,9 +368,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) {
switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
@@ -479,7 +495,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -854,7 +875,12 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -24,6 +24,8 @@ type AccessLogEntry struct {
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
}
// FromProto creates an AccessLogEntry from a proto.AccessLog
@@ -39,6 +41,8 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
a.UserId = serviceLog.GetUserId()
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
a.AccountID = serviceLog.GetAccountId()
a.BytesUpload = serviceLog.GetBytesUpload()
a.BytesDownload = serviceLog.GetBytesDownload()
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
if ip, err := netip.ParseAddr(sourceIP); err == nil {
@@ -101,5 +105,7 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
AuthMethodUsed: authMethod,
CountryCode: countryCode,
CityName: cityName,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
}
}

View File

@@ -15,3 +15,12 @@ type Domain struct {
Type Type `gorm:"-"`
Validated bool
}
// EventMeta returns activity event metadata for a domain
func (d *Domain) EventMeta() map[string]any {
return map[string]any{
"domain": d.Domain,
"target_cluster": d.TargetCluster,
"validated": d.Validated,
}
}

View File

@@ -9,4 +9,5 @@ type Manager interface {
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
GetClusterDomains() []string
}

View File

@@ -9,6 +9,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -27,25 +29,25 @@ type store interface {
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
type proxyURLProvider interface {
GetConnectedProxyURLs() []string
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
}
type Manager struct {
store store
validator domain.Validator
proxyURLProvider proxyURLProvider
proxyManager proxyManager
permissionsManager permissions.Manager
accountManager account.Manager
}
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
return Manager{
store: store,
proxyURLProvider: proxyURLProvider,
validator: domain.Validator{
Resolver: net.DefaultResolver,
},
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{Resolver: net.DefaultResolver},
permissionsManager: permissionsManager,
accountManager: accountManager,
}
}
@@ -67,8 +69,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList()
log.WithFields(log.Fields{
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
log.WithContext(ctx).WithFields(log.Fields{
"accountID": accountID,
"proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list")
@@ -107,7 +113,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
}
// Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList()
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
clusterValid := false
for _, cluster := range allowList {
if cluster == targetCluster {
@@ -129,6 +138,9 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
if err != nil {
return d, fmt.Errorf("create domain in store: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, d.ID, accountID, activity.DomainAdded, d.EventMeta())
return d, nil
}
@@ -141,10 +153,18 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
return status.NewPermissionDeniedError()
}
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
if err != nil {
return fmt.Errorf("get domain from store: %w", err)
}
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
// TODO: check for "no records" type error. Because that is a success condition.
return fmt.Errorf("delete domain from store: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, domainID, accountID, activity.DomainDeleted, d.EventMeta())
return nil
}
@@ -211,6 +231,8 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
}).WithError(err).Error("update custom domain in store")
return
}
m.accountManager.StoreEvent(context.Background(), userID, domainID, accountID, activity.DomainValidated, d.EventMeta())
} else {
log.WithFields(log.Fields{
"accountID": accountID,
@@ -221,21 +243,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
}
}
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
// GetClusterDomains returns a list of proxy cluster domains.
func (m Manager) GetClusterDomains() []string {
if m.proxyManager == nil {
return nil
}
return reverseProxyAddresses
addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
if err != nil {
return nil
}
return addresses
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList()
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available")
}

View File

@@ -1,539 +0,0 @@
package manager
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
}
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
clusterDeriver ClusterDeriver
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyGRPCServer: proxyGRPCServer,
clusterDeriver: clusterDeriver,
}
}
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
for _, target := range service.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
return nil, err
}
if err := m.persistNewService(ctx, accountID, service); err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
service.ProxyCluster = proxyCluster
}
service.AccountID = accountID
service.InitNewRecord()
if err := service.Auth.HashSecrets(); err != nil {
return fmt.Errorf("hash secrets: %w", err)
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
}
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
return nil
}
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
}
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
}
return nil
}
if existingService != nil && existingService.ID != excludeServiceID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
}
return nil
}
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := service.Auth.HashSecrets(); err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.sendServiceUpdateNotifications(service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
type serviceUpdateInfo struct {
oldCluster string
domainChanged bool
serviceEnabledChanged bool
}
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
m.preserveExistingAuthSecrets(service, existingService)
m.preserveServiceMetadata(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
})
return &updateInfo, err
}
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
} else {
service.ProxyCluster = newCluster
}
}
return nil
}
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
}
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
case !service.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
case service.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
default:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
}
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
for _, target := range targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
}
}
return nil
}
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var service *reverseproxy.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
return nil
})
if err != nil {
return err
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
service.Meta.CertificateIssuedAt = time.Now()
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
}
return nil
})
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
service.Meta.Status = string(status)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service status: %w", err)
}
return nil
})
}
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
}
return nil
}
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, service.AccountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return "", nil
}
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
}
if target == nil {
return "", nil
}
return target.ServiceID, nil
}

View File

@@ -1,375 +0,0 @@
package manager
import (
"context"
"errors"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
func TestInitializeServiceForCreate(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
mgr := &managerImpl{
clusterDeriver: nil,
}
service := &reverseproxy.Service{
Domain: "example.com",
Auth: reverseproxy.AuthConfig{},
}
err := mgr.initializeServiceForCreate(ctx, accountID, service)
assert.NoError(t, err)
assert.Equal(t, accountID, service.AccountID)
assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver")
assert.NotEmpty(t, service.ID, "service ID should be initialized")
assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated")
assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated")
})
t.Run("verifies session keys are different", func(t *testing.T) {
mgr := &managerImpl{
clusterDeriver: nil,
}
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
assert.NoError(t, err1)
assert.NoError(t, err2)
assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique")
assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique")
})
}
func TestCheckDomainAvailable(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
tests := []struct {
name string
domain string
excludeServiceID string
setupMock func(*store.MockStore)
expectedError bool
errorType status.Type
}{
{
name: "domain available - not found",
domain: "available.com",
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "available.com").
Return(nil, status.Errorf(status.NotFound, "not found"))
},
expectedError: false,
},
{
name: "domain already exists",
domain: "exists.com",
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
},
{
name: "domain exists but excluded (same ID)",
domain: "exists.com",
excludeServiceID: "service-123",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: false,
},
{
name: "domain exists with different ID",
domain: "exists.com",
excludeServiceID: "service-456",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
},
{
name: "store error (non-NotFound)",
domain: "error.com",
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "error.com").
Return(nil, errors.New("database error"))
},
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
tt.setupMock(mockStore)
mgr := &managerImpl{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
if tt.expectedError {
require.Error(t, err)
if tt.errorType != 0 {
sErr, ok := status.FromError(err)
require.True(t, ok, "error should be a status error")
assert.Equal(t, tt.errorType, sErr.Type())
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
t.Run("empty domain", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "").
Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &managerImpl{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
assert.NoError(t, err)
})
t.Run("empty exclude ID with existing service", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com").
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &managerImpl{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
assert.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.AlreadyExists, sErr.Type())
})
t.Run("nil existing service with nil error", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "nil.com").
Return(nil, nil)
mgr := &managerImpl{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
assert.NoError(t, err)
})
}
func TestPersistNewService(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
t.Run("successful service creation with no targets", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
ID: "service-123",
Domain: "new.com",
Targets: []*reverseproxy.Target{},
}
// Mock ExecuteInTransaction to execute the function immediately
mockStore.EXPECT().
ExecuteInTransaction(ctx, gomock.Any()).
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
// Create another mock for the transaction
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "new.com").
Return(nil, status.Errorf(status.NotFound, "not found"))
txMock.EXPECT().
CreateService(ctx, service).
Return(nil)
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
assert.NoError(t, err)
})
t.Run("domain already exists", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
ID: "service-123",
Domain: "existing.com",
Targets: []*reverseproxy.Target{},
}
mockStore.EXPECT().
ExecuteInTransaction(ctx, gomock.Any()).
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com").
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.AlreadyExists, sErr.Type())
})
}
func TestPreserveExistingAuthSecrets(t *testing.T) {
mgr := &managerImpl{}
t.Run("preserve password when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
Enabled: true,
Password: "hashed-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
}
mgr.preserveExistingAuthSecrets(updated, existing)
assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
})
t.Run("preserve pin when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
Enabled: true,
Pin: "hashed-pin",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
Enabled: true,
Pin: "",
},
},
}
mgr.preserveExistingAuthSecrets(updated, existing)
assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth)
})
t.Run("do not preserve when password is provided", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
Enabled: true,
Password: "old-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
Enabled: true,
Password: "new-password",
},
},
}
mgr.preserveExistingAuthSecrets(updated, existing)
assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password)
assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
})
}
func TestPreserveServiceMetadata(t *testing.T) {
mgr := &managerImpl{}
existing := &reverseproxy.Service{
Meta: reverseproxy.ServiceMeta{
CertificateIssuedAt: time.Now(),
Status: "active",
},
SessionPrivateKey: "private-key",
SessionPublicKey: "public-key",
}
updated := &reverseproxy.Service{
Domain: "updated.com",
}
mgr.preserveServiceMetadata(updated, existing)
assert.Equal(t, existing.Meta, updated.Meta)
assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey)
assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey)
}

View File

@@ -0,0 +1,36 @@
package proxy
//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"time"
"github.com/netbirdio/netbird/shared/management/proto"
)
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
// Controller is responsible for managing proxy clusters and routing service updates.
type Controller interface {
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
GetOIDCValidationConfig() OIDCValidationConfig
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
}

View File

@@ -0,0 +1,88 @@
package manager
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/shared/management/proto"
)
// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCController struct {
proxyGRPCServer *nbgrpc.ProxyServiceServer
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
metrics *metrics
}
// NewGRPCController creates a new GRPCController.
func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
}
return &GRPCController{
proxyGRPCServer: proxyGRPCServer,
metrics: m,
}, nil
}
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
c.metrics.IncrementServiceUpdateSendCount(clusterAddr)
}
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return c.proxyGRPCServer.GetOIDCValidationConfig()
}
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
c.metrics.IncrementProxyConnectionCount(clusterAddr)
return nil
}
// UnregisterProxyFromCluster removes a proxy from a cluster.
func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
c.metrics.DecrementProxyConnectionCount(clusterAddr)
}
return nil
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)
if !ok {
return nil
}
var proxies []string
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxies = append(proxies, key.(string))
return true
})
return proxies
}

View File

@@ -0,0 +1,115 @@
package manager
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
// Manager handles all proxy operations
type Manager struct {
store store
metrics *metrics
}
// NewManager creates a new proxy Manager
func NewManager(store store, meter metric.Meter) (*Manager, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
}
return &Manager{
store: store,
metrics: m,
}, nil
}
// Connect registers a new proxy connection in the database
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return nil
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
}).Info("proxy disconnected")
return nil
}
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
m.metrics.IncrementProxyHeartbeatCount()
return nil
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
return addresses, nil
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}

View File

@@ -0,0 +1,74 @@
package manager
import (
"context"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
type metrics struct {
proxyConnectionCount metric.Int64UpDownCounter
serviceUpdateSendCount metric.Int64Counter
proxyHeartbeatCount metric.Int64Counter
}
func newMetrics(meter metric.Meter) (*metrics, error) {
proxyConnectionCount, err := meter.Int64UpDownCounter(
"management_proxy_connection_count",
metric.WithDescription("Number of active proxy connections"),
metric.WithUnit("{connection}"),
)
if err != nil {
return nil, err
}
serviceUpdateSendCount, err := meter.Int64Counter(
"management_proxy_service_update_send_count",
metric.WithDescription("Total number of service updates sent to proxies"),
metric.WithUnit("{update}"),
)
if err != nil {
return nil, err
}
proxyHeartbeatCount, err := meter.Int64Counter(
"management_proxy_heartbeat_count",
metric.WithDescription("Total number of proxy heartbeats received"),
metric.WithUnit("{heartbeat}"),
)
if err != nil {
return nil, err
}
return &metrics{
proxyConnectionCount: proxyConnectionCount,
serviceUpdateSendCount: serviceUpdateSendCount,
proxyHeartbeatCount: proxyHeartbeatCount,
}, nil
}
func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), -1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) {
m.serviceUpdateSendCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementProxyHeartbeatCount() {
m.proxyHeartbeatCount.Add(context.Background(), 1)
}

View File

@@ -0,0 +1,199 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./manager.go
// Package proxy is a generated GoMock package.
package proxy
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
proto "github.com/netbirdio/netbird/shared/management/proto"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CleanupStale mocks base method.
func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStale indicates an expected call of CleanupStale.
func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
}
// Disconnect mocks base method.
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Disconnect indicates an expected call of Disconnect.
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
}
// GetActiveClusterAddresses mocks base method.
func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses.
func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller
recorder *MockControllerMockRecorder
}
// MockControllerMockRecorder is the mock recorder for MockController.
type MockControllerMockRecorder struct {
mock *MockController
}
// NewMockController creates a new mock instance.
func NewMockController(ctrl *gomock.Controller) *MockController {
mock := &MockController{ctrl: ctrl}
mock.recorder = &MockControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
ret0, _ := ret[0].([]string)
return ret0
}
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr)
}
// RegisterProxyToCluster mocks base method.
func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
}
// SendServiceUpdateToCluster mocks base method.
func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
}
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
}
// UnregisterProxyFromCluster mocks base method.
func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
}

View File

@@ -0,0 +1,20 @@
package proxy
import "time"
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (Proxy) TableName() string {
return "proxies"
}

View File

@@ -1,463 +0,0 @@
package reverseproxy
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Operation string
const (
Create Operation = "create"
Update Operation = "update"
Delete Operation = "delete"
)
type ProxyStatus string
const (
StatusPending ProxyStatus = "pending"
StatusActive ProxyStatus = "active"
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
StatusCertificatePending ProxyStatus = "certificate_pending"
StatusCertificateFailed ProxyStatus = "certificate_failed"
StatusError ProxyStatus = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
)
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
}
type PasswordAuthConfig struct {
Enabled bool `json:"enabled"`
Password string `json:"password"`
}
type PINAuthConfig struct {
Enabled bool `json:"enabled"`
Pin string `json:"pin"`
}
type BearerAuthConfig struct {
Enabled bool `json:"enabled"`
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
func (a *AuthConfig) HashSecrets() error {
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
a.PasswordAuth.Password = hashedPassword
}
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
if err != nil {
return fmt.Errorf("hash pin: %w", err)
}
a.PinAuth.Pin = hashedPin
}
return nil
}
func (a *AuthConfig) ClearSecrets() {
if a.PasswordAuth != nil {
a.PasswordAuth.Password = ""
}
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
}
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
type ServiceMeta struct {
CreatedAt time.Time
CertificateIssuedAt time.Time
Status string
}
type Service struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
s := &Service{
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
}
s.InitNewRecord()
return s
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = ServiceMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
}
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
authConfig := api.ServiceAuthConfig{}
if s.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
}
}
if s.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
}
}
if s.Auth.BearerAuth != nil {
authConfig.BearerAuth = &api.BearerAuthConfig{
Enabled: s.Auth.BearerAuth.Enabled,
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
})
}
meta := api.ServiceMeta{
CreatedAt: s.Meta.CreatedAt,
Status: api.ServiceMetaStatus(s.Meta.Status),
}
if !s.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
}
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
}
if s.ProxyCluster != "" {
resp.ProxyCluster = &s.ProxyCluster
}
return resp
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
pathMappings = append(pathMappings, &proto.PathMapping{
Path: path,
Target: targetURL.String(),
})
}
auth := &proto.Authentication{
SessionKey: s.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
auth.Password = true
}
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
auth.Pin = true
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
auth.Oidc = true
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
}
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
case Update:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
case Delete:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
default:
log.Fatalf("unknown operation type: %v", op)
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
}
}
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http).
func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType),
Enabled: apiTarget.Enabled,
}
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
targets = append(targets, target)
}
s.Targets = targets
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
s.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
s.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
}
}
if req.Auth.BearerAuth != nil {
bearerAuth := &BearerAuthConfig{
Enabled: req.Auth.BearerAuth.Enabled,
}
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
}
s.Auth.BearerAuth = bearerAuth
}
}
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
}
if len(s.Name) > 255 {
return errors.New("service name exceeds maximum length of 255 characters")
}
if s.Domain == "" {
return errors.New("service domain is required")
}
if len(s.Targets) == 0 {
return errors.New("at least one target is required")
}
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
}
return nil
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
}
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
targets[i] = &targetCopy
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
Name: s.Name,
Domain: s.Domain,
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
}
}
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,405 +0,0 @@
package reverseproxy
import (
"errors"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
func validProxy() *Service {
return &Service{
Name: "test",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
},
}
}
func TestValidate_Valid(t *testing.T) {
require.NoError(t, validProxy().Validate())
}
func TestValidate_EmptyName(t *testing.T) {
rp := validProxy()
rp.Name = ""
assert.ErrorContains(t, rp.Validate(), "name is required")
}
func TestValidate_EmptyDomain(t *testing.T) {
rp := validProxy()
rp.Domain = ""
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
}
func TestValidate_EmptyTargetId(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetId = ""
assert.ErrorContains(t, rp.Validate(), "empty target_id")
}
func TestValidate_InvalidTargetType(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetType = "invalid"
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
}
func TestValidate_ResourceTarget(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "resource-1",
TargetType: TargetTypeHost,
Host: "example.org",
Port: 443,
Protocol: "https",
Enabled: true,
})
require.NoError(t, rp.Validate())
}
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "",
TargetType: TargetTypePeer,
Host: "10.0.0.2",
Port: 80,
Protocol: "http",
Enabled: true,
})
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "target 1")
assert.Contains(t, err.Error(), "empty target_id")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
want bool
}{
{"http", 80, true},
{"https", 443, true},
{"http", 443, false},
{"https", 80, false},
{"http", 8080, false},
{"https", 8443, false},
{"http", 0, false},
{"https", 0, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
})
}
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := OIDCValidationConfig{}
tests := []struct {
name string
protocol string
host string
port int
wantTarget string
}{
{
name: "http with default port 80 omits port",
protocol: "http",
host: "10.0.0.1",
port: 80,
wantTarget: "http://10.0.0.1/",
},
{
name: "https with default port 443 omits port",
protocol: "https",
host: "10.0.0.1",
port: 443,
wantTarget: "https://10.0.0.1/",
},
{
name: "port 0 omits port",
protocol: "http",
host: "10.0.0.1",
port: 0,
wantTarget: "http://10.0.0.1/",
},
{
name: "non-default port is included",
protocol: "http",
host: "10.0.0.1",
port: 8080,
wantTarget: "http://10.0.0.1:8080/",
},
{
name: "https with non-default port is included",
protocol: "https",
host: "10.0.0.1",
port: 8443,
wantTarget: "https://10.0.0.1:8443/",
},
{
name: "http port 443 is included",
protocol: "http",
host: "10.0.0.1",
port: 443,
wantTarget: "http://10.0.0.1:443/",
},
{
name: "https port 80 is included",
protocol: "https",
host: "10.0.0.1",
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: tt.host,
Port: tt.port,
Protocol: tt.protocol,
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
require.Len(t, pm.Path, 1, "should have one path mapping")
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
})
}
}
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
func TestToProtoMapping_OperationTypes(t *testing.T) {
rp := validProxy()
tests := []struct {
op Operation
want proto.ProxyMappingUpdateType
}{
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}
}
func TestAuthConfig_HashSecrets(t *testing.T) {
tests := []struct {
name string
config *AuthConfig
wantErr bool
validate func(*testing.T, *AuthConfig)
}{
{
name: "hash password successfully",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "testPassword123",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
}
// Verify the hash can be verified
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash PIN successfully",
config: &AuthConfig{
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "123456",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
}
// Verify the hash can be verified
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash both password and PIN",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "password",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "9999",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id")
}
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
t.Errorf("Password hash verification failed: %v", err)
}
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
t.Errorf("PIN hash verification failed: %v", err)
}
},
},
{
name: "skip disabled password auth",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: false,
Password: "password",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "password" {
t.Errorf("Disabled password auth should not be hashed")
}
},
},
{
name: "skip empty password",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "" {
t.Errorf("Empty password should remain empty")
}
},
},
{
name: "skip nil password auth",
config: &AuthConfig{
PasswordAuth: nil,
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "1234",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth != nil {
t.Errorf("PasswordAuth should remain nil")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN should still be hashed")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.HashSecrets()
if (err != nil) != tt.wantErr {
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validate != nil {
tt.validate(t, tt.config)
}
})
}
}
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "correctPassword",
},
}
if err := config.HashSecrets(); err != nil {
t.Fatalf("HashSecrets() error = %v", err)
}
// Verify with wrong password should fail
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
}
}
func TestAuthConfig_ClearSecrets(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "hashedPassword",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "hashedPin",
},
}
config.ClearSecrets()
if config.PasswordAuth.Password != "" {
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
}
if config.PinAuth.Pin != "" {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
}

View File

@@ -1,6 +1,6 @@
package reverseproxy
package service
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
@@ -12,12 +12,17 @@ type Manager interface {
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
DeleteAllServices(ctx context.Context, accountID, userID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error)
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StartExposeReaper(ctx context.Context)
}

View File

@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
// Package reverseproxy is a generated GoMock package.
package reverseproxy
// Package service is a generated GoMock package.
package service
import (
context "context"
@@ -49,6 +49,35 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}
// CreateServiceFromPeer mocks base method.
func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, req)
ret0, _ := ret[0].(*ExposeServiceResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer.
func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, req interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
}
// DeleteAllServices mocks base method.
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAllServices indicates an expected call of DeleteAllServices.
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
@@ -181,6 +210,20 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
}
// RenewServiceFromPeer mocks base method.
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
}
// SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
@@ -196,7 +239,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
}
// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error)
@@ -209,6 +252,32 @@ func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
}
// StartExposeReaper mocks base method.
func (m *MockManager) StartExposeReaper(ctx context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "StartExposeReaper", ctx)
}
// StartExposeReaper indicates an expected call of StartExposeReaper.
func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartExposeReaper", reflect.TypeOf((*MockManager)(nil).StartExposeReaper), ctx)
}
// StopServiceFromPeer mocks base method.
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
}
// UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()

View File

@@ -6,10 +6,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -17,11 +17,11 @@ import (
)
type handler struct {
manager reverseproxy.Manager
manager rpservice.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
@@ -72,8 +72,11 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
service := new(rpservice.Service)
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
@@ -130,9 +133,12 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service := new(rpservice.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)

View File

@@ -0,0 +1,65 @@
package manager
import (
"context"
"math/rand/v2"
"time"
"github.com/netbirdio/netbird/shared/management/status"
log "github.com/sirupsen/logrus"
)
const (
exposeTTL = 90 * time.Second
exposeReapInterval = 30 * time.Second
maxExposesPerPeer = 10
exposeReapBatch = 100
)
type exposeReaper struct {
manager *Manager
}
// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB.
func (r *exposeReaper) StartExposeReaper(ctx context.Context) {
go func() {
// start with a random delay
rn := rand.IntN(10)
time.Sleep(time.Duration(rn) * time.Second)
ticker := time.NewTicker(exposeReapInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
r.reapExpiredExposes(ctx)
}
}
}()
}
func (r *exposeReaper) reapExpiredExposes(ctx context.Context) {
expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch)
if err != nil {
log.Errorf("failed to get expired ephemeral services: %v", err)
return
}
for _, svc := range expired {
log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain)
err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID)
if err == nil {
continue
}
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
log.Debugf("service %s was already deleted by another instance", svc.Domain)
} else {
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err)
}
}
}

View File

@@ -0,0 +1,208 @@
package manager
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
)
func TestReapExpiredExposes(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
// Manually expire the service by backdating meta_last_renewed_at
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Create a non-expired service
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8081,
Protocol: "http",
})
require.NoError(t, err)
mgr.exposeReaper.reapExpiredExposes(ctx)
// Expired service should be deleted
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "expired service should be deleted")
// Non-expired service should remain
_, err = testStore.GetServiceByDomain(ctx, resp2.Domain)
require.NoError(t, err, "active service should remain")
}
func TestReapAlreadyDeletedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Delete the service before reaping
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Reaping should handle the already-deleted service gracefully
mgr.exposeReaper.reapExpiredExposes(ctx)
}
func TestConcurrentReapAndRenew(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
// Expire all services
services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
for _, svc := range services {
if svc.Source == rpservice.SourceEphemeral {
expireEphemeralService(t, testStore, testAccountID, svc.Domain)
}
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
mgr.exposeReaper.reapExpiredExposes(ctx)
}()
go func() {
defer wg.Done()
_, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
}()
wg.Wait()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "all expired services should be reaped")
}
func TestRenewEphemeralService(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
t.Run("renew succeeds for active service", func(t *testing.T) {
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8082,
Protocol: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
})
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
require.Error(t, err)
assert.Contains(t, err.Error(), "no active expose session")
})
}
func TestCountAndExistsEphemeralServices(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8083,
Protocol: "http",
})
require.NoError(t, err)
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
assert.True(t, exists, "service should exist")
exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain")
require.NoError(t, err)
assert.False(t, exists, "non-existent service should not exist")
}
func TestMaxExposesPerPeerEnforced(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
for i := range maxExposesPerPeer {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8090 + i,
Protocol: "http",
})
require.NoError(t, err, "expose %d should succeed", i)
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9999,
Protocol: "http",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
}
func TestReapSkipsRenewedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8086,
Protocol: "http",
})
require.NoError(t, err)
// Expire the service
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Renew it before the reaper runs
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Reaper should skip it because the re-check sees a fresh timestamp
mgr.exposeReaper.reapExpiredExposes(ctx)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err, "renewed service should survive reaping")
}
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), domain)
require.NoError(t, err)
expired := time.Now().Add(-2 * exposeTTL)
svc.Meta.LastRenewedAt = &expired
err = s.UpdateService(context.Background(), svc)
require.NoError(t, err)
}

View File

@@ -0,0 +1,928 @@
package manager
import (
"context"
"fmt"
"math/rand/v2"
"slices"
"time"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
GetClusterDomains() []string
}
type Manager struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyController proxy.Controller
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyController: proxyController,
clusterDeriver: clusterDeriver,
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr
}
// StartExposeReaper starts the background goroutine that reaps expired ephemeral services.
func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeReaper.StartExposeReaper(ctx)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
for _, target := range s.Targets {
switch target.TargetType {
case service.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case service.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case service.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err
}
if err := m.persistNewService(ctx, accountID, s); err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return s, nil
}
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
service.ProxyCluster = proxyCluster
}
service.AccountID = accountID
service.InitNewRecord()
if err := service.Auth.HashSecrets(); err != nil {
return fmt.Errorf("hash secrets: %w", err)
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
}
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
return nil
})
}
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
}
return nil
}
if existingService != nil && existingService.ID != excludeServiceID {
return status.Errorf(status.AlreadyExists, "domain already taken")
}
return nil
}
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := service.Auth.HashSecrets(); err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
type serviceUpdateInfo struct {
oldCluster string
domainChanged bool
serviceEnabledChanged bool
}
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
m.preserveExistingAuthSecrets(service, existingService)
m.preserveServiceMetadata(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
})
return &updateInfo, err
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
} else {
service.ProxyCluster = newCluster
}
}
return nil
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
}
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
case s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
default:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
}
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
for _, target := range targets {
switch target.TargetType {
case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
}
}
return nil
}
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var s *service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
return nil
})
if err != nil {
return err
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var services []*service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return err
}
for _, svc := range services {
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
}
return nil
})
if err != nil {
return err
}
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, svc := range services {
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
now := time.Now()
service.Meta.CertificateIssuedAt = &now
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
}
return nil
})
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
service.Meta.Status = string(status)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service status: %w", err)
}
return nil
})
}
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
}
return nil
}
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, service.AccountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return "", nil
}
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
}
if target == nil {
return "", nil
}
return target.ServiceID, nil
}
// validateExposePermission checks whether the peer is allowed to use the expose feature.
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return status.Errorf(status.Internal, "get account settings: %v", err)
}
if !settings.PeerExposeEnabled {
return status.Errorf(status.PermissionDenied, "peer expose is not enabled for this account")
}
if len(settings.PeerExposeGroups) == 0 {
return status.Errorf(status.PermissionDenied, "no group is set for peer expose")
}
peerGroupIDs, err := m.store.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer group IDs: %v", err)
return status.Errorf(status.Internal, "get peer groups: %v", err)
}
for _, pg := range peerGroupIDs {
if slices.Contains(settings.PeerExposeGroups, pg) {
return nil
}
}
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
if err := req.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
}
if err := m.validateExposePermission(ctx, accountID, peerID); err != nil {
return nil, err
}
serviceName, err := service.GenerateExposeName(req.NamePrefix)
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
}
svc := req.ToService(accountID, peerID, serviceName)
svc.Source = service.SourceEphemeral
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
}
svc.Domain = domain
}
if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
if err != nil {
return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
}
svc.Auth.BearerAuth.DistributionGroups = groupIDs
}
if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
return nil, err
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, err
}
svc.SourcePeer = peerID
now := time.Now()
svc.Meta.LastRenewedAt = &now
if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil {
return nil, err
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
}, nil
}
func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
if len(groupNames) == 0 {
return []string{}, fmt.Errorf("no group names provided")
}
groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
}
groupIDs = append(groupIDs, g.ID)
}
return groupIDs, nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
}
index := rand.IntN(len(clusterDomains))
domain := name + "." + clusterDomains[index]
return domain, nil
}
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
}
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
return err
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, domain)
if err != nil {
return nil, err
}
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return svc, nil
}
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
var svc *service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
return nil
})
if err != nil {
return err
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
peer = nil
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// deleteExpiredPeerService deletes an ephemeral service by ID after re-checking
// that it is still expired under a row lock. This prevents deleting a service
// that was renewed between the batch query and this delete, and ensures only one
// management instance processes the deletion
func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error {
var svc *service.Service
deleted := false
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if svc.Source != service.SourceEphemeral || svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner")
}
if svc.Meta.LastRenewedAt != nil && time.Since(*svc.Meta.LastRenewedAt) <= exposeTTL {
return nil
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
deleted = true
return nil
})
if err != nil {
return err
}
if !deleted {
return nil
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
peer = nil
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any {
if peer == nil {
return meta
}
meta["peer_name"] = peer.Name
if peer.IP != nil {
meta["peer_ip"] = peer.IP.String()
}
return meta
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,817 @@
package service
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Operation string
const (
Create Operation = "create"
Update Operation = "update"
Delete Operation = "delete"
)
type Status string
const (
StatusPending Status = "pending"
StatusActive Status = "active"
StatusTunnelNotCreated Status = "tunnel_not_created"
StatusCertificatePending Status = "certificate_pending"
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
SourcePermanent = "permanent"
SourceEphemeral = "ephemeral"
)
type TargetOptions struct {
SkipTLSVerify bool `json:"skip_tls_verify"`
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
}
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
Options TargetOptions `gorm:"embedded" json:"options"`
}
type PasswordAuthConfig struct {
Enabled bool `json:"enabled"`
Password string `json:"password"`
}
type PINAuthConfig struct {
Enabled bool `json:"enabled"`
Pin string `json:"pin"`
}
type BearerAuthConfig struct {
Enabled bool `json:"enabled"`
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
func (a *AuthConfig) HashSecrets() error {
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
a.PasswordAuth.Password = hashedPassword
}
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
if err != nil {
return fmt.Errorf("hash pin: %w", err)
}
a.PinAuth.Pin = hashedPin
}
return nil
}
func (a *AuthConfig) ClearSecrets() {
if a.PasswordAuth != nil {
a.PasswordAuth.Password = ""
}
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
}
type Meta struct {
CreatedAt time.Time
CertificateIssuedAt *time.Time
Status string
LastRenewedAt *time.Time
}
type Service struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"type:varchar(255);uniqueIndex"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
s := &Service{
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
}
s.InitNewRecord()
return s
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = Meta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
}
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
authConfig := api.ServiceAuthConfig{}
if s.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
}
}
if s.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
}
}
if s.Auth.BearerAuth != nil {
authConfig.BearerAuth = &api.BearerAuthConfig{
Enabled: s.Auth.BearerAuth.Enabled,
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
st := api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
}
st.Options = targetOptionsToAPI(target.Options)
apiTargets = append(apiTargets, st)
}
meta := api.ServiceMeta{
CreatedAt: s.Meta.CreatedAt,
Status: api.ServiceMetaStatus(s.Meta.Status),
}
if s.Meta.CertificateIssuedAt != nil {
meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt
}
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
}
if s.ProxyCluster != "" {
resp.ProxyCluster = &s.ProxyCluster
}
return resp
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
pm := &proto.PathMapping{
Path: path,
Target: targetURL.String(),
}
pm.Options = targetOptionsToProto(target.Options)
pathMappings = append(pathMappings, pm)
}
auth := &proto.Authentication{
SessionKey: s.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
auth.Password = true
}
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
auth.Pin = true
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
auth.Oidc = true
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
}
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
case Update:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
case Delete:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
default:
log.Fatalf("unknown operation type: %v", op)
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
}
}
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http).
func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
// PathRewriteMode controls how the request path is rewritten before forwarding.
type PathRewriteMode string
const (
PathRewritePreserve PathRewriteMode = "preserve"
)
func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
switch mode {
case PathRewritePreserve:
return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
default:
return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
}
}
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
return nil
}
apiOpts := &api.ServiceTargetOptions{}
if opts.SkipTLSVerify {
apiOpts.SkipTlsVerify = &opts.SkipTLSVerify
}
if opts.RequestTimeout != 0 {
s := opts.RequestTimeout.String()
apiOpts.RequestTimeout = &s
}
if opts.PathRewrite != "" {
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
apiOpts.PathRewrite = &pr
}
if len(opts.CustomHeaders) > 0 {
apiOpts.CustomHeaders = &opts.CustomHeaders
}
return apiOpts
}
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
return nil
}
popts := &proto.PathTargetOptions{
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
}
if opts.RequestTimeout != 0 {
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
}
return popts
}
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
var opts TargetOptions
if o.SkipTlsVerify != nil {
opts.SkipTLSVerify = *o.SkipTlsVerify
}
if o.RequestTimeout != nil {
d, err := time.ParseDuration(*o.RequestTimeout)
if err != nil {
return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *o.RequestTimeout, err)
}
opts.RequestTimeout = d
}
if o.PathRewrite != nil {
opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
}
if o.CustomHeaders != nil {
opts.CustomHeaders = *o.CustomHeaders
}
return opts, nil
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for i, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType),
Enabled: apiTarget.Enabled,
}
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
if apiTarget.Options != nil {
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
if err != nil {
return err
}
target.Options = opts
}
targets = append(targets, target)
}
s.Targets = targets
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
s.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
s.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
}
}
if req.Auth.BearerAuth != nil {
bearerAuth := &BearerAuthConfig{
Enabled: req.Auth.BearerAuth.Enabled,
}
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
}
s.Auth.BearerAuth = bearerAuth
}
return nil
}
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
}
if len(s.Name) > 255 {
return errors.New("service name exceeds maximum length of 255 characters")
}
if s.Domain == "" {
return errors.New("service domain is required")
}
if len(s.Targets) == 0 {
return errors.New("at least one target is required")
}
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
if err := validateTargetOptions(i, &target.Options); err != nil {
return err
}
}
return nil
}
const (
maxRequestTimeout = 5 * time.Minute
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
)
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)
// hopByHopHeaders are headers that must not be set as custom headers
// because they are connection-level and stripped by the proxy.
var hopByHopHeaders = map[string]struct{}{
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"Proxy-Connection": {},
"Te": {},
"Trailer": {},
"Transfer-Encoding": {},
"Upgrade": {},
}
// reservedHeaders are set authoritatively by the proxy or control HTTP framing
// and cannot be overridden.
var reservedHeaders = map[string]struct{}{
"Content-Length": {},
"Content-Type": {},
"Cookie": {},
"Forwarded": {},
"X-Forwarded-For": {},
"X-Forwarded-Host": {},
"X-Forwarded-Port": {},
"X-Forwarded-Proto": {},
"X-Real-Ip": {},
}
func validateTargetOptions(idx int, opts *TargetOptions) error {
if opts.PathRewrite != "" && opts.PathRewrite != PathRewritePreserve {
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
}
if opts.RequestTimeout != 0 {
if opts.RequestTimeout <= 0 {
return fmt.Errorf("target %d: request_timeout must be positive", idx)
}
if opts.RequestTimeout > maxRequestTimeout {
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
}
}
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
return err
}
return nil
}
func validateCustomHeaders(idx int, headers map[string]string) error {
if len(headers) > maxCustomHeaders {
return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders)
}
seen := make(map[string]string, len(headers))
for key, value := range headers {
if !httpHeaderNameRe.MatchString(key) {
return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key)
}
if len(key) > maxHeaderKeyLen {
return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen)
}
if len(value) > maxHeaderValueLen {
return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen)
}
if containsCRLF(key) || containsCRLF(value) {
return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key)
}
canonical := http.CanonicalHeaderKey(key)
if prev, ok := seen[canonical]; ok {
return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, prev, key, canonical)
}
seen[canonical] = key
if _, ok := hopByHopHeaders[canonical]; ok {
return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key)
}
if _, ok := reservedHeaders[canonical]; ok {
return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key)
}
if canonical == "Host" {
return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
}
}
return nil
}
func containsCRLF(s string) bool {
return strings.ContainsAny(s, "\r\n")
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
}
func (s *Service) isAuthEnabled() bool {
return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil
}
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
if len(target.Options.CustomHeaders) > 0 {
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
for k, v := range target.Options.CustomHeaders {
targetCopy.Options.CustomHeaders[k] = v
}
}
targets[i] = &targetCopy
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
Name: s.Name,
Domain: s.Domain,
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
Source: s.Source,
SourcePeer: s.SourcePeer,
}
}
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
const alphanumCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`)
// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service.
type ExposeServiceRequest struct {
NamePrefix string
Port int
Protocol string
Domain string
Pin string
Password string
UserGroups []string
}
// Validate checks all fields of the expose request.
func (r *ExposeServiceRequest) Validate() error {
if r == nil {
return errors.New("request cannot be nil")
}
if r.Port < 1 || r.Port > 65535 {
return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port)
}
if r.Protocol != "http" && r.Protocol != "https" {
return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol)
}
if r.Pin != "" && !pinRegexp.MatchString(r.Pin) {
return errors.New("invalid pin: must be exactly 6 digits")
}
for _, g := range r.UserGroups {
if g == "" {
return errors.New("user group name cannot be empty")
}
}
if r.NamePrefix != "" && !validNamePrefix.MatchString(r.NamePrefix) {
return fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", r.NamePrefix)
}
return nil
}
// ToService builds a Service from the expose request.
func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service {
service := &Service{
AccountID: accountID,
Name: serviceName,
Enabled: true,
Targets: []*Target{
{
AccountID: accountID,
Port: r.Port,
Protocol: r.Protocol,
TargetId: peerID,
TargetType: TargetTypePeer,
Enabled: true,
},
},
}
if r.Domain != "" {
service.Domain = serviceName + "." + r.Domain
}
if r.Pin != "" {
service.Auth.PinAuth = &PINAuthConfig{
Enabled: true,
Pin: r.Pin,
}
}
if r.Password != "" {
service.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: true,
Password: r.Password,
}
}
if len(r.UserGroups) > 0 {
service.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: r.UserGroups,
}
}
return service
}
// ExposeServiceResponse contains the result of a successful peer expose creation.
type ExposeServiceResponse struct {
ServiceName string
ServiceURL string
Domain string
}
// GenerateExposeName generates a random service name for peer-exposed services.
// The prefix, if provided, must be a valid DNS label component (lowercase alphanumeric and hyphens).
func GenerateExposeName(prefix string) (string, error) {
if prefix != "" && !validNamePrefix.MatchString(prefix) {
return "", fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", prefix)
}
suffixLen := 12
if prefix != "" {
suffixLen = 4
}
suffix, err := randomAlphanumeric(suffixLen)
if err != nil {
return "", fmt.Errorf("generate random name: %w", err)
}
if prefix == "" {
return suffix, nil
}
return prefix + "-" + suffix, nil
}
func randomAlphanumeric(n int) (string, error) {
result := make([]byte, n)
charsetLen := big.NewInt(int64(len(alphanumCharset)))
for i := range result {
idx, err := rand.Int(rand.Reader, charsetLen)
if err != nil {
return "", err
}
result[i] = alphanumCharset[idx.Int64()]
}
return string(result), nil
}

View File

@@ -0,0 +1,732 @@
package service
import (
"errors"
"fmt"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
func validProxy() *Service {
return &Service{
Name: "test",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
},
}
}
func TestValidate_Valid(t *testing.T) {
require.NoError(t, validProxy().Validate())
}
func TestValidate_EmptyName(t *testing.T) {
rp := validProxy()
rp.Name = ""
assert.ErrorContains(t, rp.Validate(), "name is required")
}
func TestValidate_EmptyDomain(t *testing.T) {
rp := validProxy()
rp.Domain = ""
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
}
func TestValidate_EmptyTargetId(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetId = ""
assert.ErrorContains(t, rp.Validate(), "empty target_id")
}
func TestValidate_InvalidTargetType(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetType = "invalid"
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
}
func TestValidate_ResourceTarget(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "resource-1",
TargetType: TargetTypeHost,
Host: "example.org",
Port: 443,
Protocol: "https",
Enabled: true,
})
require.NoError(t, rp.Validate())
}
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "",
TargetType: TargetTypePeer,
Host: "10.0.0.2",
Port: 80,
Protocol: "http",
Enabled: true,
})
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "target 1")
assert.Contains(t, err.Error(), "empty target_id")
}
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
tests := []struct {
name string
mode PathRewriteMode
wantErr string
}{
{"empty is default", "", ""},
{"preserve is valid", PathRewritePreserve, ""},
{"unknown rejected", "regex", "unknown path_rewrite mode"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.PathRewrite = tt.mode
err := rp.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
wantErr string
}{
{"valid 30s", 30 * time.Second, ""},
{"valid 2m", 2 * time.Minute, ""},
{"zero is fine", 0, ""},
{"negative", -1 * time.Second, "must be positive"},
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.RequestTimeout = tt.timeout
err := rp.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
t.Run("valid headers", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{
"X-Custom": "value",
"X-Trace": "abc123",
}
assert.NoError(t, rp.Validate())
})
t.Run("CRLF in key", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"}
assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name")
})
t.Run("CRLF in value", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"}
assert.ErrorContains(t, rp.Validate(), "invalid characters")
})
t.Run("hop-by-hop header rejected", func(t *testing.T) {
for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h)
}
})
t.Run("reserved header rejected", func(t *testing.T) {
for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h)
}
})
t.Run("Host header rejected", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"}
assert.ErrorContains(t, rp.Validate(), "pass_host_header")
})
t.Run("too many headers", func(t *testing.T) {
rp := validProxy()
headers := make(map[string]string, 17)
for i := range 17 {
headers[fmt.Sprintf("X-H%d", i)] = "v"
}
rp.Targets[0].Options.CustomHeaders = headers
assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16")
})
t.Run("key too long", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"}
assert.ErrorContains(t, rp.Validate(), "key")
assert.ErrorContains(t, rp.Validate(), "exceeds maximum length")
})
t.Run("value too long", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)}
assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length")
})
t.Run("duplicate canonical keys rejected", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{
"x-custom": "a",
"X-Custom": "b",
}
assert.ErrorContains(t, rp.Validate(), "collide")
})
}
func TestToProtoMapping_TargetOptions(t *testing.T) {
rp := &Service{
ID: "svc-1",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
Enabled: true,
Options: TargetOptions{
SkipTLSVerify: true,
RequestTimeout: 30 * time.Second,
PathRewrite: PathRewritePreserve,
CustomHeaders: map[string]string{"X-Custom": "val"},
},
},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
opts := pm.Path[0].Options
require.NotNil(t, opts, "options should be populated")
assert.True(t, opts.SkipTlsVerify)
assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite)
assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders)
require.NotNil(t, opts.RequestTimeout)
assert.Equal(t, int64(30), opts.RequestTimeout.Seconds)
}
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
rp := &Service{
ID: "svc-1",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
want bool
}{
{"http", 80, true},
{"https", 443, true},
{"http", 443, false},
{"https", 80, false},
{"http", 8080, false},
{"https", 8443, false},
{"http", 0, false},
{"https", 0, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
})
}
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := proxy.OIDCValidationConfig{}
tests := []struct {
name string
protocol string
host string
port int
wantTarget string
}{
{
name: "http with default port 80 omits port",
protocol: "http",
host: "10.0.0.1",
port: 80,
wantTarget: "http://10.0.0.1/",
},
{
name: "https with default port 443 omits port",
protocol: "https",
host: "10.0.0.1",
port: 443,
wantTarget: "https://10.0.0.1/",
},
{
name: "port 0 omits port",
protocol: "http",
host: "10.0.0.1",
port: 0,
wantTarget: "http://10.0.0.1/",
},
{
name: "non-default port is included",
protocol: "http",
host: "10.0.0.1",
port: 8080,
wantTarget: "http://10.0.0.1:8080/",
},
{
name: "https with non-default port is included",
protocol: "https",
host: "10.0.0.1",
port: 8443,
wantTarget: "https://10.0.0.1:8443/",
},
{
name: "http port 443 is included",
protocol: "http",
host: "10.0.0.1",
port: 443,
wantTarget: "http://10.0.0.1:443/",
},
{
name: "https port 80 is included",
protocol: "https",
host: "10.0.0.1",
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: tt.host,
Port: tt.port,
Protocol: tt.protocol,
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
require.Len(t, pm.Path, 1, "should have one path mapping")
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
})
}
}
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
func TestToProtoMapping_OperationTypes(t *testing.T) {
rp := validProxy()
tests := []struct {
op Operation
want proto.ProxyMappingUpdateType
}{
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}
}
func TestAuthConfig_HashSecrets(t *testing.T) {
tests := []struct {
name string
config *AuthConfig
wantErr bool
validate func(*testing.T, *AuthConfig)
}{
{
name: "hash password successfully",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "testPassword123",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
}
// Verify the hash can be verified
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash PIN successfully",
config: &AuthConfig{
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "123456",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
}
// Verify the hash can be verified
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash both password and PIN",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "password",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "9999",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id")
}
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
t.Errorf("Password hash verification failed: %v", err)
}
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
t.Errorf("PIN hash verification failed: %v", err)
}
},
},
{
name: "skip disabled password auth",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: false,
Password: "password",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "password" {
t.Errorf("Disabled password auth should not be hashed")
}
},
},
{
name: "skip empty password",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "" {
t.Errorf("Empty password should remain empty")
}
},
},
{
name: "skip nil password auth",
config: &AuthConfig{
PasswordAuth: nil,
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "1234",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth != nil {
t.Errorf("PasswordAuth should remain nil")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN should still be hashed")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.HashSecrets()
if (err != nil) != tt.wantErr {
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validate != nil {
tt.validate(t, tt.config)
}
})
}
}
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "correctPassword",
},
}
if err := config.HashSecrets(); err != nil {
t.Fatalf("HashSecrets() error = %v", err)
}
// Verify with wrong password should fail
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
}
}
func TestAuthConfig_ClearSecrets(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "hashedPassword",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "hashedPin",
},
}
config.ClearSecrets()
if config.PasswordAuth.Password != "" {
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
}
if config.PinAuth.Pin != "" {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
}
func TestGenerateExposeName(t *testing.T) {
t.Run("no prefix generates 12-char name", func(t *testing.T) {
name, err := GenerateExposeName("")
require.NoError(t, err)
assert.Len(t, name, 12)
assert.Regexp(t, `^[a-z0-9]+$`, name)
})
t.Run("with prefix generates prefix-XXXX", func(t *testing.T) {
name, err := GenerateExposeName("myapp")
require.NoError(t, err)
assert.True(t, strings.HasPrefix(name, "myapp-"), "name should start with prefix")
suffix := strings.TrimPrefix(name, "myapp-")
assert.Len(t, suffix, 4, "suffix should be 4 chars")
assert.Regexp(t, `^[a-z0-9]+$`, suffix)
})
t.Run("unique names", func(t *testing.T) {
names := make(map[string]bool)
for i := 0; i < 50; i++ {
name, err := GenerateExposeName("")
require.NoError(t, err)
names[name] = true
}
assert.Greater(t, len(names), 45, "should generate mostly unique names")
})
t.Run("valid prefixes", func(t *testing.T) {
validPrefixes := []string{"a", "ab", "a1", "my-app", "web-server-01", "a-b"}
for _, prefix := range validPrefixes {
name, err := GenerateExposeName(prefix)
assert.NoError(t, err, "prefix %q should be valid", prefix)
assert.True(t, strings.HasPrefix(name, prefix+"-"), "name should start with %q-", prefix)
}
})
t.Run("invalid prefixes", func(t *testing.T) {
invalidPrefixes := []string{
"-starts-with-dash",
"ends-with-dash-",
"has.dots",
"HAS-UPPER",
"has spaces",
"has/slash",
"a--",
}
for _, prefix := range invalidPrefixes {
_, err := GenerateExposeName(prefix)
assert.Error(t, err, "prefix %q should be invalid", prefix)
assert.Contains(t, err.Error(), "invalid name prefix")
}
})
}
func TestExposeServiceRequest_ToService(t *testing.T) {
t.Run("basic HTTP service", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
service := req.ToService("account-1", "peer-1", "mysvc")
assert.Equal(t, "account-1", service.AccountID)
assert.Equal(t, "mysvc", service.Name)
assert.True(t, service.Enabled)
assert.Empty(t, service.Domain, "domain should be empty when not specified")
require.Len(t, service.Targets, 1)
target := service.Targets[0]
assert.Equal(t, 8080, target.Port)
assert.Equal(t, "http", target.Protocol)
assert.Equal(t, "peer-1", target.TargetId)
assert.Equal(t, TargetTypePeer, target.TargetType)
assert.True(t, target.Enabled)
assert.Equal(t, "account-1", target.AccountID)
})
t.Run("with custom domain", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 3000,
Domain: "example.com",
}
service := req.ToService("acc", "peer", "web")
assert.Equal(t, "web.example.com", service.Domain)
})
t.Run("with PIN auth", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 80,
Pin: "1234",
}
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.PinAuth)
assert.True(t, service.Auth.PinAuth.Enabled)
assert.Equal(t, "1234", service.Auth.PinAuth.Pin)
assert.Nil(t, service.Auth.PasswordAuth)
assert.Nil(t, service.Auth.BearerAuth)
})
t.Run("with password auth", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 80,
Password: "secret",
}
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.PasswordAuth)
assert.True(t, service.Auth.PasswordAuth.Enabled)
assert.Equal(t, "secret", service.Auth.PasswordAuth.Password)
})
t.Run("with user groups (bearer auth)", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 80,
UserGroups: []string{"admins", "devs"},
}
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.BearerAuth)
assert.True(t, service.Auth.BearerAuth.Enabled)
assert.Equal(t, []string{"admins", "devs"}, service.Auth.BearerAuth.DistributionGroups)
})
t.Run("with all auth types", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 443,
Domain: "myco.com",
Pin: "9999",
Password: "pass",
UserGroups: []string{"ops"},
}
service := req.ToService("acc", "peer", "full")
assert.Equal(t, "full.myco.com", service.Domain)
require.NotNil(t, service.Auth.PinAuth)
require.NotNil(t, service.Auth.PasswordAuth)
require.NotNil(t, service.Auth.BearerAuth)
})
}

View File

@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if s.Config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil {
log.Fatalf("failed to create certificate manager: %v", err)
log.Fatalf("failed to create certificate service: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
@@ -152,6 +152,11 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
serviceMgr := s.ServiceManager()
srv.SetReverseProxyManager(serviceMgr)
if serviceMgr != nil {
serviceMgr.StartExposeReaper(context.Background())
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
@@ -163,9 +168,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager())
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
})
return proxyService
})
@@ -188,12 +194,25 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
return Create(s, func() *nbgrpc.OneTimeTokenStore {
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create proxy token store: %v", err)
}
log.Info("One-time token store initialized for proxy authentication")
return tokenStore
})
}
func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
return Create(s, func() *nbgrpc.PKCEVerifierStore {
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create PKCE verifier store: %v", err)
}
return pkceStore
})
}
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())

View File

@@ -6,6 +6,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -106,6 +108,16 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
})
}
func (s *BaseServer) ServiceProxyController() proxy.Controller {
return Create(s, func() proxy.Controller {
controller, err := proxymanager.NewGRPCController(s.ReverseProxyGRPCServer(), s.Metrics().GetMeter())
if err != nil {
log.Fatalf("failed to create service proxy controller: %v", err)
}
return controller
})
}
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return Create(s, func() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store())

View File

@@ -8,9 +8,11 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
@@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
log.Fatalf("failed to create account service: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetServiceManager(s.ReverseProxyManager())
accountManager.SetServiceManager(s.ServiceManager())
})
return accountManager
@@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
// Use embedded IdP manager if embedded Dex is configured and enabled.
// Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured.
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil {
log.Fatalf("failed to create embedded IDP manager: %v", err)
log.Fatalf("failed to create embedded IDP service: %v", err)
}
return idpManager
}
// Fall back to external IdP manager
// Fall back to external IdP service
if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP manager: %v", err)
log.Fatalf("failed to create IDP service: %v", err)
}
}
return idpManager
})
}
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
return nil
@@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
})
}
@@ -190,15 +192,25 @@ func (s *BaseServer) RecordsManager() records.Manager {
})
}
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
return Create(s, func() reverseproxy.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
})
}
func (s *BaseServer) ProxyManager() proxy.Manager {
return Create(s, func() proxy.Manager {
manager, err := proxymanager.NewManager(s.Store(), s.Metrics().GetMeter())
if err != nil {
log.Fatalf("failed to create proxy manager: %v", err)
}
return manager
})
}
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
return &m
})
}

View File

@@ -28,9 +28,13 @@ import (
"github.com/netbirdio/netbird/version"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
const ManagementLegacyPort = 33073
const (
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
ManagementLegacyPort = 33073
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
DefaultSelfHostedDomain = "netbird.selfhosted"
)
type Server interface {
Start(ctx context.Context) error
@@ -58,6 +62,7 @@ type BaseServer struct {
mgmtMetricsPort int
mgmtPort int
disableLegacyManagementPort bool
autoResolveDomains bool
proxyAuthClose func()
@@ -81,6 +86,7 @@ type Config struct {
DisableMetrics bool
DisableGeoliteUpdate bool
UserDeleteFromIDPEnabled bool
AutoResolveDomains bool
}
// NewServer initializes and configures a new Server instance
@@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer {
mgmtPort: cfg.MgmtPort,
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
mgmtMetricsPort: cfg.MgmtMetricsPort,
autoResolveDomains: cfg.AutoResolveDomains,
}
}
@@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.cancel = cancel
s.errCh = make(chan error, 4)
if s.autoResolveDomains {
s.resolveDomains(srvCtx)
}
s.PeersManager()
s.GeoLocationManager()
@@ -157,7 +168,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
// Eagerly create the gRPC server so that all AfterInit hooks are registered
// before we iterate them. Lazy creation after the loop would miss hooks
// registered during GRPCServer() construction (e.g., SetProxyManager).
// registered during GRPCServer() construction (e.g., SetServiceManager).
s.GRPCServer()
for _, fn := range s.afterInit {
@@ -237,7 +248,6 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
s.ReverseProxyGRPCServer().Close()
if s.proxyAuthClose != nil {
s.proxyAuthClose()
s.proxyAuthClose = nil
@@ -381,6 +391,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
}()
}
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
// Fresh installs use the default self-hosted domain, while existing installs reuse the
// persisted account domain to keep addressing stable across config changes.
func (s *BaseServer) resolveDomains(ctx context.Context) {
st := s.Store()
setDefault := func(logMsg string, args ...any) {
if logMsg != "" {
log.WithContext(ctx).Warnf(logMsg, args...)
}
s.dnsDomain = DefaultSelfHostedDomain
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
}
accountsCount, err := st.GetAccountsCounter(ctx)
if err != nil {
setDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain)
return
}
if accountsCount == 0 {
s.dnsDomain = DefaultSelfHostedDomain
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain)
return
}
accountID, err := st.GetAnyAccountID(ctx)
if err != nil {
setDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain)
return
}
if accountID == "" {
setDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain)
return
}
domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil {
setDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain)
return
}
if domain == "" {
setDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain)
return
}
s.dnsDomain = domain
s.mgmtSingleAccModeDomain = domain
log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain)
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {

View File

@@ -0,0 +1,63 @@
package server
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
)
func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil)
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
}
func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil)
mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil)
mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil)
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
}
func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed"))
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
}

View File

@@ -0,0 +1,192 @@
package grpc
import (
"context"
pb "github.com/golang/protobuf/proto" // nolint
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/proto"
internalStatus "github.com/netbirdio/netbird/shared/management/status"
)
// CreateExpose handles a peer request to create a new expose service.
func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
exposeReq := &proto.ExposeServiceRequest{}
peerKey, err := s.parseRequest(ctx, req, exposeReq)
if err != nil {
return nil, err
}
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
NamePrefix: exposeReq.NamePrefix,
Port: int(exposeReq.Port),
Protocol: exposeProtocolToString(exposeReq.Protocol),
Domain: exposeReq.Domain,
Pin: exposeReq.Pin,
Password: exposeReq.Password,
UserGroups: exposeReq.UserGroups,
})
if err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{
ServiceName: created.ServiceName,
ServiceUrl: created.ServiceURL,
Domain: created.Domain,
})
}
// RenewExpose extends the TTL of an active expose session.
func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
renewReq := &proto.RenewExposeRequest{}
peerKey, err := s.parseRequest(ctx, req, renewReq)
if err != nil {
return nil, err
}
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.RenewExposeResponse{})
}
// StopExpose terminates an active expose session.
func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
stopReq := &proto.StopExposeRequest{}
peerKey, err := s.parseRequest(ctx, req, stopReq)
if err != nil {
return nil, err
}
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.StopExposeResponse{})
}
func mapExposeError(ctx context.Context, err error) error {
s, ok := internalStatus.FromError(err)
if !ok {
log.WithContext(ctx).Errorf("expose service error: %v", err)
return status.Errorf(codes.Internal, "internal error")
}
switch s.Type() {
case internalStatus.InvalidArgument:
return status.Errorf(codes.InvalidArgument, "%s", s.Message)
case internalStatus.PermissionDenied:
return status.Errorf(codes.PermissionDenied, "%s", s.Message)
case internalStatus.NotFound:
return status.Errorf(codes.NotFound, "%s", s.Message)
case internalStatus.AlreadyExists:
return status.Errorf(codes.AlreadyExists, "%s", s.Message)
case internalStatus.PreconditionFailed:
return status.Errorf(codes.ResourceExhausted, "%s", s.Message)
default:
log.WithContext(ctx).Errorf("expose service error: %v", err)
return status.Errorf(codes.Internal, "internal error")
}
}
func (s *Server) encryptResponse(peerKey wgtypes.Key, msg pb.Message) (*proto.EncryptedMessage, error) {
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, wgKey, msg)
if err != nil {
return nil, status.Errorf(codes.Internal, "encrypt response")
}
return &proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encryptedResp,
}, nil
}
func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key) (string, *nbpeer.Peer, error) {
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return "", nil, status.Errorf(codes.Internal, "lookup account for peer")
}
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil {
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return accountID, peer, nil
}
func (s *Server) getReverseProxyManager() rpservice.Manager {
s.reverseProxyMu.RLock()
defer s.reverseProxyMu.RUnlock()
return s.reverseProxyManager
}
// SetReverseProxyManager sets the reverse proxy manager on the server.
func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
s.reverseProxyMu.Lock()
defer s.reverseProxyMu.Unlock()
s.reverseProxyManager = mgr
}
func exposeProtocolToString(p proto.ExposeProtocol) string {
switch p {
case proto.ExposeProtocol_EXPOSE_HTTP:
return "http"
case proto.ExposeProtocol_EXPOSE_HTTPS:
return "https"
default:
return "http"
}
}

View File

@@ -1,28 +1,23 @@
package grpc
import (
"context"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a service is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
mu sync.RWMutex
cleanup *time.Ticker
cleanupDone chan struct{}
}
// tokenMetadata stores information about a one-time token
type tokenMetadata struct {
ServiceID string
AccountID string
@@ -30,20 +25,24 @@ type tokenMetadata struct {
CreatedAt time.Time
}
// NewOneTimeTokenStore creates a new token store with automatic cleanup
// of expired tokens. The cleanupInterval determines how often expired
// tokens are removed from memory.
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
store := &OneTimeTokenStore{
tokens: make(map[string]*tokenMetadata),
cleanup: time.NewTicker(cleanupInterval),
cleanupDone: make(chan struct{}),
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
type OneTimeTokenStore struct {
cache *cache.Cache[string]
ctx context.Context
}
// NewOneTimeTokenStore creates a token store with automatic backend selection
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
}
// Start background cleanup goroutine
go store.cleanupExpired()
return store
return &OneTimeTokenStore{
cache: cache.New[string](cacheStore),
ctx: ctx,
}, nil
}
// GenerateToken creates a new cryptographically secure one-time token
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
//
// Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("failed to generate random token: %w", err)
}
// Encode as URL-safe base64 for easy transmission in gRPC
token := base64.URLEncoding.EncodeToString(randomBytes)
hashedToken := hashToken(token)
s.mu.Lock()
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
metadata := &tokenMetadata{
ServiceID: serviceID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
}
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
}
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
return "", fmt.Errorf("failed to store token: %w", err)
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
serviceID, accountID, ttl)
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
// - Account ID doesn't match
// - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
s.mu.Lock()
defer s.mu.Unlock()
hashedToken := hashToken(token)
metadata, exists := s.tokens[token]
if !exists {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
serviceID, accountID)
metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
if err != nil {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
return fmt.Errorf("invalid token")
}
// Check expiration
metadata := &tokenMetadata{}
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
return fmt.Errorf("invalid token metadata")
}
if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
serviceID, accountID)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
return fmt.Errorf("token expired")
}
// Validate account ID using constant-time comparison (prevents timing attacks)
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
metadata.AccountID, accountID)
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
return fmt.Errorf("account ID mismatch")
}
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
metadata.ServiceID, serviceID)
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch")
}
// Delete token immediately to enforce single-use
delete(s.tokens, token)
if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
}
log.Infof("Token validated and consumed for proxy %s in account %s",
serviceID, accountID)
log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
return nil
}
// cleanupExpired removes expired tokens in the background to prevent memory leaks
func (s *OneTimeTokenStore) cleanupExpired() {
for {
select {
case <-s.cleanup.C:
s.mu.Lock()
now := time.Now()
removed := 0
for token, metadata := range s.tokens {
if now.After(metadata.ExpiresAt) {
delete(s.tokens, token)
removed++
}
}
if removed > 0 {
log.Debugf("Cleaned up %d expired one-time tokens", removed)
}
s.mu.Unlock()
case <-s.cleanupDone:
return
}
}
}
// Close stops the cleanup goroutine and releases resources
func (s *OneTimeTokenStore) Close() {
s.cleanup.Stop()
close(s.cleanupDone)
}
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
func (s *OneTimeTokenStore) GetTokenCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.tokens)
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}

View File

@@ -0,0 +1,61 @@
package grpc
import (
"context"
"fmt"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
type PKCEVerifierStore struct {
cache *cache.Cache[string]
ctx context.Context
}
// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
}
return &PKCEVerifierStore{
cache: cache.New[string](cacheStore),
ctx: ctx,
}, nil
}
// Store saves a PKCE verifier associated with an OAuth state parameter.
// The verifier is stored with the specified TTL and will be automatically deleted after expiration.
func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error {
if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil {
return fmt.Errorf("failed to store PKCE verifier: %w", err)
}
log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl)
return nil
}
// LoadAndDelete retrieves and removes a PKCE verifier for the given state.
// Returns the verifier and true if found, or empty string and false if not found.
// This enforces single-use semantics for PKCE verifiers.
func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) {
verifier, err := s.cache.Get(s.ctx, state)
if err != nil {
log.Debugf("PKCE verifier not found for state")
return "", false
}
if err := s.cache.Delete(s.ctx, state); err != nil {
log.Warnf("Failed to delete PKCE verifier for state: %v", err)
}
return verifier, true
}

View File

@@ -18,14 +18,14 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
@@ -58,17 +58,17 @@ type ProxyServiceServer struct {
// Map of connected proxies: proxy_id -> proxy connection
connectedProxies sync.Map
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
// Channel for broadcasting reverse proxy updates to all proxies
updatesChan chan *proto.ProxyMapping
// Manager for access logs
accessLogManager accesslogs.Manager
// Manager for reverse proxy operations
reverseProxyManager reverseproxy.Manager
serviceManager rpservice.Manager
// ProxyController for service updates and cluster management
proxyController proxy.Controller
// Manager for proxy connections
proxyManager proxy.Manager
// Manager for peers
peersManager peers.Manager
@@ -82,84 +82,67 @@ type ProxyServiceServer struct {
// OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig
// TODO: use database to store these instead?
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state.
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
pkceVerifiers sync.Map
pkceCleanupCancel context.CancelFunc
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
}
const pkceVerifierTTL = 10 * time.Minute
type pkceEntry struct {
verifier string
createdAt time.Time
}
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
address string
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.ProxyMapping
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx := context.Background()
s := &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100),
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
peersManager: peersManager,
usersManager: usersManager,
pkceCleanupCancel: cancel,
proxyManager: proxyMgr,
}
go s.cleanupPKCEVerifiers(ctx)
go s.cleanupStaleProxies(ctx)
return s
}
// cleanupPKCEVerifiers periodically removes expired PKCE verifiers.
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
ticker := time.NewTicker(pkceVerifierTTL)
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
s.pkceVerifiers.Range(func(key, value any) bool {
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
s.pkceVerifiers.Delete(key)
}
return true
})
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
}
}
}
}
// Close stops background goroutines.
func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel()
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.serviceManager = manager
}
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
s.reverseProxyManager = manager
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
s.proxyController = proxyController
}
// GetMappingUpdate handles the control stream with proxy clients
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
ctx := stream.Context()
peerInfo := ""
if p, ok := peer.FromContext(ctx); ok {
peerInfo = p.Addr.String()
}
peerInfo := PeerIPFromContext(ctx)
log.Infof("New proxy connection from %s", peerInfo)
proxyID := req.GetProxyId()
@@ -177,13 +160,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyID: proxyID,
address: proxyAddress,
stream: stream,
sendChan: make(chan *proto.ProxyMapping, 100),
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx,
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)
s.addToCluster(conn.address, proxyID)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
}
log.WithFields(log.Fields{
"proxy_id": proxyID,
"address": proxyAddress,
@@ -191,8 +182,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
s.connectedProxies.Delete(proxyID)
s.removeFromCluster(conn.address, proxyID)
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
cancel()
log.Infof("Proxy %s disconnected", proxyID)
}()
@@ -204,6 +202,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
errChan := make(chan error, 2)
go s.sender(conn, errChan)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID)
select {
case err := <-errChan:
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
@@ -212,10 +213,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
}
case <-ctx.Done():
return
}
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return fmt.Errorf("get services from store: %w", err)
}
@@ -224,7 +242,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return fmt.Errorf("proxy address is invalid")
}
var filtered []*reverseproxy.Service
var filtered []*rpservice.Service
for _, service := range services {
if !service.Enabled {
continue
@@ -259,7 +277,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
service.ToProtoMapping(
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
token,
s.GetOIDCValidationConfig(),
),
@@ -288,7 +306,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
for {
select {
case msg := <-conn.sendChan:
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil {
if err := conn.stream.Send(msg); err != nil {
errChan <- err
return
}
@@ -339,7 +357,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
log.Debugf("Broadcasting service update to all connected proxy servers")
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
@@ -349,7 +367,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
}
select {
case conn.sendChan <- msg:
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
}
@@ -393,81 +411,75 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
return urls
}
// addToCluster registers a proxy in a cluster.
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
}
// removeFromCluster removes a proxy from a cluster.
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
}
}
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
updateResponse := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{update},
}
if clusterAddr == "" {
s.SendServiceUpdate(update)
s.SendServiceUpdate(updateResponse)
return
}
proxySet, ok := s.clusterProxies.Load(clusterAddr)
if !ok {
log.Debugf("No proxies connected for cluster %s", clusterAddr)
if s.proxyController == nil {
log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
return
}
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
if len(proxyIDs) == 0 {
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
return
}
log.Debugf("Sending service update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxyID := key.(string)
for _, proxyID := range proxyIDs {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(update, proxyID)
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
return true
continue
}
select {
case conn.sendChan <- msg:
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
}
return true
})
}
}
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original message is
// returned unchanged because proxies do not need to authenticate for removal.
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped).
func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping {
if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" {
return update
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, mapping := range update.Mapping {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
resp = append(resp, mapping)
continue
}
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil
}
msg := shallowCloneMapping(mapping)
msg.AuthToken = token
resp = append(resp, msg)
}
token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute)
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil
return &proto.GetMappingUpdateResponse{
Mapping: resp,
}
msg := shallowCloneMapping(update)
msg.AuthToken = token
return msg
}
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
@@ -486,35 +498,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
}
}
// GetAvailableClusters returns information about all connected proxy clusters.
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
clusterCounts := make(map[string]int)
s.clusterProxies.Range(func(key, value interface{}) bool {
clusterAddr := key.(string)
proxySet := value.(*sync.Map)
count := 0
proxySet.Range(func(_, _ interface{}) bool {
count++
return true
})
if count > 0 {
clusterCounts[clusterAddr] = count
}
return true
})
clusters := make([]ClusterInfo, 0, len(clusterCounts))
for addr, count := range clusterCounts {
clusters = append(clusters, ClusterInfo{
Address: addr,
ConnectedProxies: count,
})
}
return clusters
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
@@ -533,7 +518,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
}, nil
}
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) {
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin:
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
@@ -544,7 +529,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
}
}
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) {
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
return false, "", ""
@@ -558,7 +543,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
return true, "pin-user", proxyauth.MethodPIN
}
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) {
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
return false, "", ""
@@ -580,7 +565,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
}
}
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) {
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
if !authenticated || service.SessionPrivateKey == "" {
return "", nil
}
@@ -620,7 +605,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
}
if certificateIssued {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
}
@@ -632,7 +617,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
internalStatus := protoStatusToInternal(protoStatus)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
}
@@ -647,22 +632,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
}
// protoStatusToInternal maps proto status to internal status
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus {
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
switch protoStatus {
case proto.ProxyStatus_PROXY_STATUS_PENDING:
return reverseproxy.StatusPending
return rpservice.StatusPending
case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
return reverseproxy.StatusActive
return rpservice.StatusActive
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
return reverseproxy.StatusTunnelNotCreated
return rpservice.StatusTunnelNotCreated
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
return reverseproxy.StatusCertificatePending
return rpservice.StatusCertificatePending
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
return reverseproxy.StatusCertificateFailed
return rpservice.StatusCertificateFailed
case proto.ProxyStatus_PROXY_STATUS_ERROR:
return reverseproxy.StatusError
return rpservice.StatusError
default:
return reverseproxy.StatusError
return rpservice.StatusError
}
}
@@ -727,7 +712,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
}
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil {
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
@@ -771,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
codeVerifier := oauth2.GenerateVerifier()
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
}
return &proto.GetOIDCURLResponse{
Url: (&oauth2.Config{
@@ -790,8 +778,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
// GetOIDCValidationConfig returns the OIDC configuration for token validation
// in the format needed by ToProtoMapping.
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig {
return reverseproxy.OIDCValidationConfig{
func (s *ProxyServiceServer) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return proxy.OIDCValidationConfig{
Issuer: s.oidcConfig.Issuer,
Audiences: []string{s.oidcConfig.Audience},
KeysLocation: s.oidcConfig.KeysLocation,
@@ -808,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
// ValidateState validates the state parameter from an OAuth callback.
// Returns the original redirect URL if valid, or an error if invalid.
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
v, ok := s.pkceVerifiers.LoadAndDelete(state)
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
entry, ok := v.(pkceEntry)
if !ok {
return "", "", errors.New("invalid verifier for state")
}
if time.Since(entry.createdAt) > pkceVerifierTTL {
return "", "", errors.New("PKCE verifier expired")
}
verifier = entry.verifier
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|")
@@ -850,12 +830,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return "", fmt.Errorf("get services: %w", err)
}
var service *reverseproxy.Service
var service *rpservice.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
@@ -921,8 +901,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
}
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetAccountServices(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account services: %w", err)
}
@@ -1043,8 +1023,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services: %w", err)
}
@@ -1058,7 +1038,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
return nil, fmt.Errorf("service not found for domain: %s", domain)
}
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
}

View File

@@ -107,7 +107,7 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
}
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
clientIP := peerIPFromContext(ctx)
clientIP := PeerIPFromContext(ctx)
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")

View File

@@ -115,9 +115,9 @@ func (l *authFailureLimiter) stop() {
l.cancel()
}
// peerIPFromContext extracts the client IP from the gRPC context.
// PeerIPFromContext extracts the client IP from the gRPC context.
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
func peerIPFromContext(ctx context.Context) clientIP {
func PeerIPFromContext(ctx context.Context) string {
if addr, ok := realip.FromContext(ctx); ok {
return addr.String()
}

View File

@@ -8,40 +8,44 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/types"
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service
proxiesByAccount map[string][]*service.Service
err error
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
return nil
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return []*reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return []*service.Service{}, nil
}
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
@@ -52,7 +56,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac
return nil
}
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
return nil
}
@@ -64,14 +68,28 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID,
return nil
}
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return &service.ExposeServiceResponse{}, nil
}
func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
type mockUsersManager struct {
users map[string]*types.User
err error
@@ -93,7 +111,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name string
domain string
userID string
proxiesByAccount map[string][]*reverseproxy.Service
proxiesByAccount map[string][]*service.Service
users map[string]*types.User
proxyErr error
userErr error
@@ -104,7 +122,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
@@ -115,7 +133,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{},
proxiesByAccount: map[string][]*service.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
@@ -126,7 +144,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
@@ -139,8 +157,8 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
@@ -151,12 +169,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{Enabled: false},
},
}},
},
@@ -169,12 +187,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{},
},
@@ -190,12 +208,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
@@ -212,12 +230,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
@@ -233,12 +251,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
@@ -266,10 +284,10 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
{Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"},
},
},
@@ -283,7 +301,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr,
},
@@ -310,7 +328,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name string
accountID string
domain string
proxiesByAccount map[string][]*reverseproxy.Service
proxiesByAccount map[string][]*service.Service
err error
expectProxy bool
expectErr bool
@@ -319,7 +337,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
@@ -332,7 +350,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
@@ -342,7 +360,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{},
proxiesByAccount: map[string][]*service.Service{},
expectProxy: false,
expectErr: true,
},
@@ -360,7 +378,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.err,
},

View File

@@ -1,6 +1,7 @@
package grpc
import (
"context"
"crypto/rand"
"encoding/base64"
"strings"
@@ -11,13 +12,65 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/management/proto"
)
type testProxyController struct {
mu sync.Mutex
clusterProxies map[string]map[string]struct{}
}
func newTestProxyController() *testProxyController {
return &testProxyController{
clusterProxies: make(map[string]map[string]struct{}),
}
}
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
}
func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return proxy.OIDCValidationConfig{}
}
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.clusterProxies[clusterAddr]; !ok {
c.clusterProxies[clusterAddr] = make(map[string]struct{})
}
c.clusterProxies[clusterAddr][proxyID] = struct{}{}
return nil
}
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if proxies, ok := c.clusterProxies[clusterAddr]; ok {
delete(proxies, proxyID)
}
return nil
}
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
c.mu.Lock()
defer c.mu.Unlock()
proxies, ok := c.clusterProxies[clusterAddr]
if !ok {
return nil
}
result := make([]string, 0, len(proxies))
for id := range proxies {
result = append(result, id)
}
return result
}
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
ch := make(chan *proto.ProxyMapping, 10)
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.GetMappingUpdateResponse, 10)
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
@@ -25,13 +78,12 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
}
s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
_ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
return ch
}
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
select {
case msg := <-ch:
return msg
@@ -41,24 +93,29 @@ func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
const numProxies = 3
channels := make([]chan *proto.ProxyMapping, numProxies)
channels := make([]chan *proto.GetMappingUpdateResponse, numProxies)
for i := range numProxies {
id := "proxy-" + string(rune('a'+i))
channels[i] = registerFakeProxy(s, id, cluster)
}
update := &proto.ProxyMapping{
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
@@ -68,14 +125,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
},
}
s.SendServiceUpdateToCluster(update, cluster)
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
tokens := make([]string, numProxies)
for i, ch := range channels {
msg := drainChannel(ch)
require.NotNil(t, msg, "proxy %d should receive a message", i)
assert.Equal(t, update.Domain, msg.Domain)
assert.Equal(t, update.Id, msg.Id)
resp := drainChannel(ch)
require.NotNil(t, resp, "proxy %d should receive a message", i)
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
msg := resp.Mapping[0]
assert.Equal(t, mapping.Domain, msg.Domain)
assert.Equal(t, mapping.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
tokens[i] = msg.AuthToken
}
@@ -96,66 +155,84 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
}
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster)
ch2 := registerFakeProxy(s, "proxy-b", cluster)
update := &proto.ProxyMapping{
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdateToCluster(update, cluster)
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
require.NotNil(t, resp1)
require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
// Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken)
assert.Empty(t, msg2.AuthToken)
// No tokens should have been created
assert.Equal(t, 0, tokenStore.GetTokenCount())
assert.Empty(t, resp1.Mapping[0].AuthToken)
assert.Empty(t, resp2.Mapping[0].AuthToken)
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
}
s.SetProxyController(newTestProxyController())
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
update := &proto.ProxyMapping{
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
update := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{mapping},
}
s.SendServiceUpdate(update)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
require.NotNil(t, resp1)
require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
msg1 := resp1.Mapping[0]
msg2 := resp2.Mapping[0]
assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken)
@@ -178,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
}
func TestOAuthState_NeverTheSame(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
pkceVerifierStore: pkceStore,
}
redirectURL := "https://app.example.com/callback"
@@ -202,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
}
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
pkceVerifierStore: pkceStore,
}
// Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
require.NoError(t, err)
_, _, err := s.ValidateState("base64url|hmac")
_, _, err = s.ValidateState("base64url|hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format")
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
pkceVerifierStore: pkceStore,
}
// Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
require.NoError(t, err)
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
_, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/job"
@@ -80,6 +81,9 @@ type Server struct {
syncSem atomic.Int32
syncLimEnabled bool
syncLim int32
reverseProxyManager rpservice.Manager
reverseProxyMu sync.RWMutex
}
// NewServer creates a new Management server
@@ -326,13 +330,12 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
}
unlock()
unlock = nil
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
}
log.WithContext(ctx).Debugf("Sync took %s", time.Since(reqStart))
s.syncSem.Add(-1)
@@ -739,13 +742,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart))
}()
if loginReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
@@ -795,6 +791,11 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart))
return &proto.EncryptedMessage{
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,

View File

@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -34,11 +34,18 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore}
serviceManager := &testValidateSessionServiceManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore}
proxyManager := &testValidateSessionProxyManager{}
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
proxyService.SetProxyManager(proxyManager)
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)
@@ -54,7 +61,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{
testProxy := &service.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
@@ -62,15 +69,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
},
},
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
restrictedProxy := &service.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
@@ -78,8 +85,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
@@ -196,7 +203,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) {
require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
assert.Equal(t, "service_not_found", resp.DeniedReason)
}
func TestValidateSession_InvalidToken(t *testing.T) {
@@ -239,62 +246,102 @@ func TestValidateSession_MissingToken(t *testing.T) {
assert.Contains(t, resp.DeniedReason, "missing")
}
type testValidateSessionProxyManager struct {
type testValidateSessionServiceManager struct {
store store.Store
}
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
type testValidateSessionUsersManager struct {
store store.Store
}

View File

@@ -15,7 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
@@ -83,9 +83,9 @@ type DefaultAccountManager struct {
requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller
settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager
proxyController port_forwarding.Controller
settingsManager settings.Manager
serviceManager service.Manager
// config contains the management server configuration
config *nbconfig.Config
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
am.reverseProxyManager = serviceManager
func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
am.serviceManager = serviceManager
}
func isUniqueConstraintError(err error) bool {
@@ -376,6 +376,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
@@ -394,7 +395,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
}
if reloadReverseProxy {
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
}
}
@@ -492,6 +493,21 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con
}
}
func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
oldEnabled := oldSettings.PeerExposeEnabled
newEnabled := newSettings.PeerExposeEnabled
if oldEnabled == newEnabled {
return
}
event := activity.AccountPeerExposeEnabled
if !newEnabled {
event = activity.AccountPeerExposeDisabled
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
@@ -714,6 +730,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
}
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
}
for _, otherUser := range account.Users {
if otherUser.Id == userID {
continue
@@ -1358,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
userAuth.Domain = am.singleAccountModeDomain
userAuth.DomainCategory = types.PrivateCategory
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
if err != nil {
return "", "", err
}
}
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
@@ -1393,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil
}
// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
userAuth.DomainCategory = types.PrivateCategory
userAuth.Domain = am.singleAccountModeDomain
accountID, err := am.Store.GetAnyAccountID(ctx)
if err != nil {
if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound {
return err
}
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
return nil
}
if accountID == "" {
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
return nil
}
domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
userAuth.Domain = domain
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
return nil
}
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager

View File

@@ -1,12 +1,14 @@
package account
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"net"
"net/netip"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns"
@@ -61,11 +63,11 @@ type Manager interface {
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
@@ -140,5 +142,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetServiceManager(serviceManager reverseproxy.Manager)
SetServiceManager(serviceManager service.Manager)
}

File diff suppressed because it is too large Load Diff

View File

@@ -15,10 +15,12 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
@@ -27,10 +29,13 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
@@ -1802,12 +1807,12 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24",
},
},
Services: []*reverseproxy.Service{
Services: []*service.Service{
{
ID: "service1",
Name: "test-service",
AccountID: "account1",
Targets: []*reverseproxy.Target{},
Targets: []*service.Target{},
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
@@ -3112,6 +3117,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT().
CleanupStale(gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
@@ -3122,7 +3133,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
return manager, updateManager, nil
}
@@ -3951,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
}
}
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("real-domain.com", "private", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "real-domain.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.NotFound, "no accounts"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.Internal, "db down"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "db down")
// Defaults should still be set before error path
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("", "", status.Errorf(status.Internal, "query failed"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "query failed")
})
}

View File

@@ -208,6 +208,25 @@ const (
ServiceUpdated Activity = 109
ServiceDeleted Activity = 110
// PeerServiceExposed indicates that a peer exposed a service via the reverse proxy
PeerServiceExposed Activity = 111
// PeerServiceUnexposed indicates that a peer-exposed service was removed
PeerServiceUnexposed Activity = 112
// PeerServiceExposeExpired indicates that a peer-exposed service was removed due to TTL expiration
PeerServiceExposeExpired Activity = 113
// AccountPeerExposeEnabled indicates that a user enabled peer expose for the account
AccountPeerExposeEnabled Activity = 114
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
AccountPeerExposeDisabled Activity = 115
// DomainAdded indicates that a user added a custom domain
DomainAdded Activity = 118
// DomainDeleted indicates that a user deleted a custom domain
DomainDeleted Activity = 119
// DomainValidated indicates that a custom domain was validated
DomainValidated Activity = 120
AccountDeleted Activity = 99999
)
@@ -345,6 +364,17 @@ var activityMap = map[Activity]Code{
ServiceCreated: {"Service created", "service.create"},
ServiceUpdated: {"Service updated", "service.update"},
ServiceDeleted: {"Service deleted", "service.delete"},
PeerServiceExposed: {"Peer exposed service", "service.peer.expose"},
PeerServiceUnexposed: {"Peer unexposed service", "service.peer.unexpose"},
PeerServiceExposeExpired: {"Peer exposed service expired", "service.peer.expose.expire"},
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"},
}
// StringCode returns a string code of the activity

View File

@@ -249,7 +249,15 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
switch storeEngine {
case types.SqliteStoreEngine:
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
dbFile := eventSinkDB
if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" {
dbFile = envFile
}
connStr := dbFile
if !filepath.IsAbs(dbFile) {
connStr = filepath.Join(dataDir, dbFile)
}
dialector = sqlite.Open(connStr)
case types.PostgresStoreEngine:
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {

View File

@@ -425,6 +425,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
var groupIDsToDelete []string
var deletedGroups []*types.Group
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
@@ -433,7 +438,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
@@ -621,7 +626,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
return nil
}
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
@@ -641,6 +646,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if slices.Contains(flowGroups, group.ID) {
return &GroupLinkError{"settings", "traffic event logging"}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}

View File

@@ -12,6 +12,7 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -26,6 +27,7 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
@@ -284,6 +286,67 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
}
}
func TestDefaultAccountManager_DeleteGroupLinkedToFlowGroup(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
ctrl := gomock.NewController(t)
settingsMock := settings.NewMockManager(ctrl)
settingsMock.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{FlowGroups: []string{"grp-for-flow"}}, nil).
AnyTimes()
settingsMock.EXPECT().
UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(false, nil).
AnyTimes()
am.settingsManager = settingsMock
_, account, err := initTestGroupAccount(am)
require.NoError(t, err)
grp := &types.Group{
ID: "grp-for-flow",
AccountID: account.Id,
Name: "Group for flow",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
require.NoError(t, am.CreateGroup(context.Background(), account.Id, groupAdminUserID, grp))
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, "grp-for-flow")
require.Error(t, err)
var gErr *GroupLinkError
require.ErrorAs(t, err, &gErr)
assert.Equal(t, "settings", gErr.Resource)
assert.Equal(t, "traffic event logging", gErr.Name)
group, err := am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
regularGrp := &types.Group{
ID: "grp-regular",
AccountID: account.Id,
Name: "Regular group",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, regularGrp)
require.NoError(t, err)
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, []string{"grp-for-flow", "grp-regular"})
require.Error(t, err)
group, err = am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
_, err = am.GetGroup(context.Background(), account.Id, "grp-regular", groupAdminUserID)
assert.Error(t, err)
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
accountID := "testingAcc"
domain := "example.com"
@@ -703,7 +766,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
@@ -73,7 +73,7 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
}
// Register OAuth callback handler for proxy authentication

View File

@@ -168,6 +168,10 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
if req.Settings.PeerExposeEnabled && len(req.Settings.PeerExposeGroups) == 0 {
return nil, status.Errorf(status.InvalidArgument, "peer expose requires at least one group")
}
returnSettings := &types.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
@@ -175,6 +179,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
PeerExposeEnabled: req.Settings.PeerExposeEnabled,
PeerExposeGroups: req.Settings.PeerExposeGroups,
}
if req.Settings.Extra != nil {
@@ -336,6 +343,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
JwtAllowGroups: &jwtAllowGroups,
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
PeerExposeEnabled: settings.PeerExposeEnabled,
PeerExposeGroups: settings.PeerExposeGroups,
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,

View File

@@ -18,8 +18,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -190,7 +190,11 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcServer := newFakeOIDCServer()
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore)
@@ -205,12 +209,14 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
pkceStore,
oidcConfig,
nil,
usersManager,
nil,
)
proxyService.SetProxyManager(&testServiceManager{store: testStore})
proxyService.SetServiceManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService, nil)
@@ -239,12 +245,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
testProxy := &reverseproxy.Service{
testProxy := &service.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
@@ -254,8 +260,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
@@ -265,12 +271,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
restrictedProxy := &service.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
@@ -280,8 +286,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"restrictedGroupId"},
},
@@ -291,12 +297,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
noAuthProxy := &reverseproxy.Service{
noAuthProxy := &service.Service{
ID: "noAuthProxyId",
AccountID: "testAccountId",
Name: "No Auth Proxy",
Domain: "no-auth-proxy.example.com",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
@@ -306,8 +312,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: false,
},
},
@@ -357,19 +363,23 @@ type testServiceManager struct {
store store.Store
}
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
return nil
}
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil
}
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil
}
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
@@ -381,7 +391,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri
return nil
}
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil
}
@@ -393,15 +403,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error
return nil
}
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
@@ -409,6 +419,20 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri
return "", nil
}
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper()

View File

@@ -9,10 +9,13 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -91,12 +94,28 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(reverseProxyManager)
am.SetServiceManager(reverseProxyManager)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
@@ -114,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

Some files were not shown because too many files have changed in this diff Show More