mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 19:02:13 -04:00
Compare commits
21 Commits
v0.69.0
...
add-packet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
33e166fde2 | ||
|
|
3da5c7aad2 | ||
|
|
635343f14e | ||
|
|
c07c726ea7 | ||
|
|
fa0d58d093 | ||
|
|
b6038e8acd | ||
|
|
5da05ecca6 | ||
|
|
801de8c68d | ||
|
|
a822a33240 | ||
|
|
57b23c5b25 | ||
|
|
1165058fad | ||
|
|
703353d354 | ||
|
|
2fb50aef6b | ||
|
|
eb3aa96257 | ||
|
|
064ec1c832 | ||
|
|
75e408f51c | ||
|
|
5a89e6621b | ||
|
|
06dfa9d4a5 | ||
|
|
45d9ee52c0 | ||
|
|
b734534f3c | ||
|
|
e58c29d4f9 |
@@ -17,6 +17,7 @@ ENV \
|
||||
NETBIRD_BIN="/usr/local/bin/netbird" \
|
||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||
NB_ENABLE_CAPTURE="false" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
@@ -23,6 +23,7 @@ ENV \
|
||||
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
||||
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
||||
NB_DISABLE_DNS="true" \
|
||||
NB_ENABLE_CAPTURE="false" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
196
client/cmd/capture.go
Normal file
196
client/cmd/capture.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
var captureCmd = &cobra.Command{
|
||||
Use: "capture",
|
||||
Short: "Capture packets on the WireGuard interface",
|
||||
Long: `Captures decrypted packets flowing through the WireGuard interface.
|
||||
|
||||
Default output is human-readable text. Use --pcap or --output for pcap binary.
|
||||
Requires --enable-capture to be set at service install or reconfigure time.
|
||||
|
||||
Examples:
|
||||
netbird debug capture
|
||||
netbird debug capture host 100.64.0.1 and port 443
|
||||
netbird debug capture tcp
|
||||
netbird debug capture icmp
|
||||
netbird debug capture src host 10.0.0.1 and dst port 80
|
||||
netbird debug capture -o capture.pcap
|
||||
netbird debug capture --pcap | tshark -r -
|
||||
netbird debug capture --pcap | tcpdump -r - -n`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
RunE: runCapture,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.AddCommand(captureCmd)
|
||||
|
||||
captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
|
||||
captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length")
|
||||
captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)")
|
||||
captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)")
|
||||
captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)")
|
||||
captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
|
||||
}
|
||||
|
||||
func runCapture(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
cmd.PrintErrf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
req, err := buildCaptureRequest(cmd, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.StartCapture(ctx, req)
|
||||
if err != nil {
|
||||
return handleCaptureError(err)
|
||||
}
|
||||
|
||||
// First Recv is the empty acceptance message from the server. If the
|
||||
// device is unavailable (kernel WG, not connected, capture disabled),
|
||||
// the server returns an error instead.
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
return handleCaptureError(err)
|
||||
}
|
||||
|
||||
out, cleanup, err := captureOutput(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if req.TextOutput {
|
||||
cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n")
|
||||
} else {
|
||||
cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n")
|
||||
}
|
||||
|
||||
streamErr := streamCapture(ctx, cmd, stream, out)
|
||||
cleanupErr := cleanup()
|
||||
if streamErr != nil {
|
||||
return streamErr
|
||||
}
|
||||
return cleanupErr
|
||||
}
|
||||
|
||||
func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) {
|
||||
req := &proto.StartCaptureRequest{}
|
||||
|
||||
if len(args) > 0 {
|
||||
expr := strings.Join(args, " ")
|
||||
if _, err := capture.ParseFilter(expr); err != nil {
|
||||
return nil, fmt.Errorf("invalid filter: %w", err)
|
||||
}
|
||||
req.FilterExpr = expr
|
||||
}
|
||||
|
||||
if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 {
|
||||
req.SnapLen = snap
|
||||
}
|
||||
if d, _ := cmd.Flags().GetDuration("duration"); d != 0 {
|
||||
if d < 0 {
|
||||
return nil, fmt.Errorf("duration must not be negative")
|
||||
}
|
||||
req.Duration = durationpb.New(d)
|
||||
}
|
||||
req.Verbose, _ = cmd.Flags().GetBool("verbose")
|
||||
req.Ascii, _ = cmd.Flags().GetBool("ascii")
|
||||
|
||||
outPath, _ := cmd.Flags().GetString("output")
|
||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||
req.TextOutput = !forcePcap && outPath == ""
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error {
|
||||
for {
|
||||
pkt, err := stream.Recv()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
cmd.PrintErrf("\nCapture stopped.\n")
|
||||
return nil //nolint:nilerr // user interrupted
|
||||
}
|
||||
if err == io.EOF {
|
||||
cmd.PrintErrf("\nCapture finished.\n")
|
||||
return nil
|
||||
}
|
||||
return handleCaptureError(err)
|
||||
}
|
||||
if _, err := out.Write(pkt.GetData()); err != nil {
|
||||
return fmt.Errorf("write output: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// captureOutput returns the writer for capture data and a cleanup function
|
||||
// that finalizes the file. Errors from the cleanup must be propagated.
|
||||
func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) {
|
||||
outPath, _ := cmd.Flags().GetString("output")
|
||||
if outPath == "" {
|
||||
return os.Stdout, func() error { return nil }, nil
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create output file: %w", err)
|
||||
}
|
||||
tmpPath := f.Name()
|
||||
return f, func() error {
|
||||
var merr *multierror.Error
|
||||
if err := f.Close(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err))
|
||||
}
|
||||
fi, statErr := os.Stat(tmpPath)
|
||||
if statErr != nil || fi.Size() == 0 {
|
||||
if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
if err := os.Rename(tmpPath, outPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err))
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
cmd.PrintErrf("Wrote %s\n", outPath)
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func handleCaptureError(err error) error {
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return fmt.Errorf("%s", s.Message())
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
@@ -239,11 +240,50 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
}()
|
||||
}
|
||||
|
||||
captureStarted := false
|
||||
if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture {
|
||||
captureTimeout := duration + 30*time.Second
|
||||
const maxBundleCapture = 10 * time.Minute
|
||||
if captureTimeout > maxBundleCapture {
|
||||
captureTimeout = maxBundleCapture
|
||||
}
|
||||
_, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{
|
||||
Timeout: durationpb.New(captureTimeout),
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
captureStarted = true
|
||||
cmd.Println("Packet capture started.")
|
||||
// Safety: always stop on exit, even if the normal stop below runs too.
|
||||
defer func() {
|
||||
if captureStarted {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
if captureStarted {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
|
||||
} else {
|
||||
captureStarted = false
|
||||
cmd.Println("Packet capture stopped.")
|
||||
}
|
||||
}
|
||||
|
||||
if cpuProfilingStarted {
|
||||
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||
@@ -416,4 +456,5 @@ func init() {
|
||||
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
||||
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle")
|
||||
}
|
||||
|
||||
@@ -75,6 +75,7 @@ var (
|
||||
mtu uint16
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
|
||||
@@ -44,6 +44,7 @@ func init() {
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
|
||||
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
||||
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
|
||||
@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
}
|
||||
|
||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
|
||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled)
|
||||
if err := serverInstance.Start(); err != nil {
|
||||
log.Fatalf("failed to start daemon: %v", err)
|
||||
}
|
||||
|
||||
@@ -59,6 +59,10 @@ func buildServiceArguments() []string {
|
||||
args = append(args, "--disable-update-settings")
|
||||
}
|
||||
|
||||
if captureEnabled {
|
||||
args = append(args, "--enable-capture")
|
||||
}
|
||||
|
||||
if networksDisabled {
|
||||
args = append(args, "--disable-networks")
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type serviceParams struct {
|
||||
LogFiles []string `json:"log_files,omitempty"`
|
||||
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
EnableCapture bool `json:"enable_capture,omitempty"`
|
||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
@@ -79,6 +80,7 @@ func currentServiceParams() *serviceParams {
|
||||
LogFiles: logFiles,
|
||||
DisableProfiles: profilesDisabled,
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
EnableCapture: captureEnabled,
|
||||
DisableNetworks: networksDisabled,
|
||||
}
|
||||
|
||||
@@ -144,6 +146,10 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
updateSettingsDisabled = params.DisableUpdateSettings
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("enable-capture") {
|
||||
captureEnabled = params.EnableCapture
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
|
||||
networksDisabled = params.DisableNetworks
|
||||
}
|
||||
|
||||
@@ -535,6 +535,7 @@ func fieldToGlobalVar(field string) string {
|
||||
"LogFiles": "logFiles",
|
||||
"DisableProfiles": "profilesDisabled",
|
||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||
"EnableCapture": "captureEnabled",
|
||||
"DisableNetworks": "networksDisabled",
|
||||
"ServiceEnvVars": "serviceEnvVars",
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ func startClientDaemon(
|
||||
s := grpc.NewServer()
|
||||
|
||||
server := client.New(ctx,
|
||||
"", "", false, false, false)
|
||||
"", "", false, false, false, false)
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
65
client/embed/capture.go
Normal file
65
client/embed/capture.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
// CaptureOptions configures a packet capture session.
|
||||
type CaptureOptions struct {
|
||||
// Output receives pcap-formatted data. Nil disables pcap output.
|
||||
Output io.Writer
|
||||
// TextOutput receives human-readable packet summaries. Nil disables text output.
|
||||
TextOutput io.Writer
|
||||
// Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443").
|
||||
// Empty captures all packets.
|
||||
Filter string
|
||||
// Verbose adds seq/ack, TTL, window, and total length to text output.
|
||||
Verbose bool
|
||||
// ASCII dumps transport payload as printable ASCII after each packet line.
|
||||
ASCII bool
|
||||
}
|
||||
|
||||
// CaptureStats reports capture session counters.
|
||||
type CaptureStats struct {
|
||||
Packets int64
|
||||
Bytes int64
|
||||
Dropped int64
|
||||
}
|
||||
|
||||
// CaptureSession represents an active packet capture. Call Stop to end the
|
||||
// capture and flush buffered packets.
|
||||
type CaptureSession struct {
|
||||
sess *capture.Session
|
||||
engine *internal.Engine
|
||||
}
|
||||
|
||||
// Stop ends the capture, flushes remaining packets, and detaches from the device.
|
||||
// Safe to call multiple times.
|
||||
func (cs *CaptureSession) Stop() {
|
||||
if cs.engine != nil {
|
||||
_ = cs.engine.SetCapture(nil)
|
||||
cs.engine = nil
|
||||
}
|
||||
if cs.sess != nil {
|
||||
cs.sess.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// Stats returns current capture counters.
|
||||
func (cs *CaptureSession) Stats() CaptureStats {
|
||||
s := cs.sess.Stats()
|
||||
return CaptureStats{
|
||||
Packets: s.Packets,
|
||||
Bytes: s.Bytes,
|
||||
Dropped: s.Dropped,
|
||||
}
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the capture's writer goroutine
|
||||
// has fully exited and all buffered packets have been flushed.
|
||||
func (cs *CaptureSession) Done() <-chan struct{} {
|
||||
return cs.sess.Done()
|
||||
}
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -65,7 +66,7 @@ type Options struct {
|
||||
PrivateKey string
|
||||
// ManagementURL overrides the default management server URL
|
||||
ManagementURL string
|
||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||
// PreSharedKey is the pre-shared key for the tunnel interface
|
||||
PreSharedKey string
|
||||
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
|
||||
LogOutput io.Writer
|
||||
@@ -81,9 +82,9 @@ type Options struct {
|
||||
DisableClientRoutes bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
// MTU is the MTU for the WireGuard interface.
|
||||
// MTU is the MTU for the tunnel interface.
|
||||
// Valid values are in the range 576..8192 bytes.
|
||||
// If non-nil, this value overrides any value stored in the config file.
|
||||
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
|
||||
@@ -469,6 +470,52 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// StartCapture begins capturing packets on this client's tunnel device.
|
||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||
func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var matcher capture.Matcher
|
||||
if opts.Filter != "" {
|
||||
m, err := capture.ParseFilter(opts.Filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse filter: %w", err)
|
||||
}
|
||||
matcher = m
|
||||
}
|
||||
|
||||
sess, err := capture.NewSession(capture.Options{
|
||||
Output: opts.Output,
|
||||
TextOutput: opts.TextOutput,
|
||||
Matcher: matcher,
|
||||
Verbose: opts.Verbose,
|
||||
ASCII: opts.ASCII,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create capture session: %w", err)
|
||||
}
|
||||
|
||||
if err := engine.SetCapture(sess); err != nil {
|
||||
sess.Stop()
|
||||
return nil, fmt.Errorf("set capture: %w", err)
|
||||
}
|
||||
|
||||
return &CaptureSession{sess: sess, engine: engine}, nil
|
||||
}
|
||||
|
||||
// StopCapture stops the active capture session if one is running.
|
||||
func (c *Client) StopCapture() error {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.SetCapture(nil)
|
||||
}
|
||||
|
||||
// getEngine safely retrieves the engine from the client with proper locking.
|
||||
// Returns ErrClientNotStarted if the client is not started.
|
||||
// Returns ErrEngineNotStarted if the engine is not available.
|
||||
|
||||
11
client/firewall/firewalld/firewalld.go
Normal file
11
client/firewall/firewalld/firewalld.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
||||
// its wg interface into firewalld's "trusted" zone. This is required because
|
||||
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
||||
// versions, which returns EPERM to any other process that tries to insert
|
||||
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
||||
// itself add the accept rules to its own chains by trusting the interface.
|
||||
package firewalld
|
||||
|
||||
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
||||
// should bypass firewalld filtering.
|
||||
const TrustedZone = "trusted"
|
||||
260
client/firewall/firewalld/firewalld_linux.go
Normal file
260
client/firewall/firewalld/firewalld_linux.go
Normal file
@@ -0,0 +1,260 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
dbusDest = "org.fedoraproject.FirewallD1"
|
||||
dbusPath = "/org/fedoraproject/FirewallD1"
|
||||
dbusRootIface = "org.fedoraproject.FirewallD1"
|
||||
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
||||
|
||||
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
||||
errAlreadyEnabled = "ALREADY_ENABLED"
|
||||
errUnknownIface = "UNKNOWN_INTERFACE"
|
||||
errNotEnabled = "NOT_ENABLED"
|
||||
|
||||
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
||||
// A fresh context is created for each call so a slow DBus probe can't
|
||||
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
||||
callTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
||||
|
||||
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
||||
// Info level only for the first successful add per process; repeat adds
|
||||
// from other init paths are quieter.
|
||||
trustLogOnce sync.Once
|
||||
|
||||
parentCtxMu sync.RWMutex
|
||||
parentCtx context.Context = context.Background()
|
||||
)
|
||||
|
||||
// SetParentContext installs a parent context whose cancellation aborts any
|
||||
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
||||
// always uses a fresh Background-rooted timeout so cleanup can still run
|
||||
// during engine shutdown when the engine context is already cancelled.
|
||||
func SetParentContext(ctx context.Context) {
|
||||
parentCtxMu.Lock()
|
||||
parentCtx = ctx
|
||||
parentCtxMu.Unlock()
|
||||
}
|
||||
|
||||
func getParentContext() context.Context {
|
||||
parentCtxMu.RLock()
|
||||
defer parentCtxMu.RUnlock()
|
||||
return parentCtx
|
||||
}
|
||||
|
||||
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
||||
// running. It is idempotent and best-effort: errors are returned so callers
|
||||
// can log, but a non-running firewalld is not an error. Only the first
|
||||
// successful call per process logs at Info. Respects the parent context set
|
||||
// via SetParentContext so startup-time cancellation unblocks it.
|
||||
func TrustInterface(iface string) error {
|
||||
parent := getParentContext()
|
||||
if !isRunning(parent) {
|
||||
return nil
|
||||
}
|
||||
if err := addTrusted(parent, iface); err != nil {
|
||||
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
trustLogOnce.Do(func() {
|
||||
log.Infof("added %s to firewalld trusted zone", iface)
|
||||
})
|
||||
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
||||
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
||||
// during shutdown after the engine context has been cancelled.
|
||||
func UntrustInterface(iface string) error {
|
||||
if !isRunning(context.Background()) {
|
||||
return nil
|
||||
}
|
||||
if err := removeTrusted(context.Background(), iface); err != nil {
|
||||
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(parent, callTimeout)
|
||||
}
|
||||
|
||||
func isRunning(parent context.Context) bool {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
ok, err := isRunningDBus(ctx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return ok
|
||||
}
|
||||
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return isRunningCLI(ctx)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := addDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return addCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func removeTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := removeDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return removeCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func isRunningDBus(ctx context.Context) (bool, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
var zone string
|
||||
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
||||
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isRunningCLI(ctx context.Context) bool {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return false
|
||||
}
|
||||
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
||||
}
|
||||
|
||||
func addDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
||||
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
||||
if move.Err != nil {
|
||||
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func removeDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func addCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
// --change-interface (no --permanent) binds the interface for the
|
||||
// current runtime only; we do not want membership to persist across
|
||||
// reboots because netbird re-asserts it on every startup.
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(string(out))
|
||||
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbusErrContains(err error, code string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var de dbus.Error
|
||||
if errors.As(err, &de) {
|
||||
for _, b := range de.Body {
|
||||
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Contains(err.Error(), code)
|
||||
}
|
||||
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
func TestDBusErrContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
code string
|
||||
want bool
|
||||
}{
|
||||
{"nil error", nil, errZoneAlreadySet, false},
|
||||
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
||||
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
||||
{
|
||||
"dbus.Error body match",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
||||
errZoneAlreadySet,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"dbus.Error body miss",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
||||
errAlreadyEnabled,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"dbus.Error non-string body falls back to Error()",
|
||||
dbus.Error{Name: "x", Body: []any{123}},
|
||||
"x",
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := dbusErrContains(tc.err, tc.code)
|
||||
if got != tc.want {
|
||||
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
25
client/firewall/firewalld/firewalld_other.go
Normal file
25
client/firewall/firewalld/firewalld_other.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import "context"
|
||||
|
||||
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func SetParentContext(context.Context) {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
}
|
||||
|
||||
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func TrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func UntrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
// Trust after all fatal init steps so a later failure doesn't leave the
|
||||
// interface in firewalld's trusted zone without a corresponding Close.
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
}
|
||||
|
||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// attempt to delete state only if all other operations succeeded
|
||||
if merr == nil {
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
@@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
@@ -40,6 +41,8 @@ const (
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
firewalldTableName = "firewalld"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
@@ -133,6 +136,10 @@ func (r *router) Reset() error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
@@ -280,6 +287,10 @@ func (r *router) createContainers() error {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
@@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||
if chain.Table.Name == firewalldTableName {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
}
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AllowNetbird()
|
||||
}
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() wgaddr.Address
|
||||
GetWGDevice() *wgdevice.Device
|
||||
|
||||
@@ -115,12 +115,13 @@ type Manager struct {
|
||||
|
||||
localipmanager *localIPManager
|
||||
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
pendingCapture atomic.Pointer[forwarder.PacketCapture]
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
|
||||
blockRule firewall.Rule
|
||||
|
||||
@@ -351,6 +352,19 @@ func (m *Manager) determineRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetPacketCapture sets or clears packet capture on the forwarder endpoint.
|
||||
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
|
||||
func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) {
|
||||
if pc == nil {
|
||||
m.pendingCapture.Store(nil)
|
||||
} else {
|
||||
m.pendingCapture.Store(&pc)
|
||||
}
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.SetCapture(pc)
|
||||
}
|
||||
}
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder() error {
|
||||
if m.forwarder.Load() != nil {
|
||||
@@ -372,6 +386,11 @@ func (m *Manager) initForwarder() error {
|
||||
|
||||
m.forwarder.Store(forwarder)
|
||||
|
||||
// Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture.
|
||||
if pc := m.pendingCapture.Load(); pc != nil {
|
||||
forwarder.SetCapture(*pc)
|
||||
}
|
||||
|
||||
log.Debug("forwarder initialized")
|
||||
|
||||
return nil
|
||||
@@ -614,6 +633,7 @@ func (m *Manager) resetState() {
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.SetCapture(nil)
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
|
||||
@@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
NameFunc func() string
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() wgaddr.Address
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Name() string {
|
||||
if i.NameFunc == nil {
|
||||
return "wgtest"
|
||||
}
|
||||
return i.NameFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||
if i.GetWGDeviceFunc == nil {
|
||||
return nil
|
||||
|
||||
@@ -12,12 +12,19 @@ import (
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
// PacketCapture captures raw packets for debugging. Implementations must be
|
||||
// safe for concurrent use and must not block.
|
||||
type PacketCapture interface {
|
||||
Offer(data []byte, outbound bool)
|
||||
}
|
||||
|
||||
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||
type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu atomic.Uint32
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
@@ -54,13 +61,17 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet through WireGuard
|
||||
pktBytes := data.AsSlice()
|
||||
|
||||
address := netHeader.DestinationAddress()
|
||||
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||
if err != nil {
|
||||
if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
|
||||
e.logger.Error1("CreateOutboundPacket: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if pc := e.capture.Load(); pc != nil {
|
||||
(*pc).Offer(pktBytes, true)
|
||||
}
|
||||
written++
|
||||
}
|
||||
|
||||
|
||||
@@ -139,6 +139,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// SetCapture sets or clears the packet capture on the forwarder endpoint.
|
||||
// This captures outbound packets that bypass the FilteredDevice (netstack forwarding).
|
||||
func (f *Forwarder) SetCapture(pc PacketCapture) {
|
||||
if pc == nil {
|
||||
f.endpoint.capture.Store(nil)
|
||||
return
|
||||
}
|
||||
f.endpoint.capture.Store(&pc)
|
||||
}
|
||||
|
||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||
|
||||
@@ -270,5 +270,9 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []
|
||||
return 0
|
||||
}
|
||||
|
||||
if pc := f.endpoint.capture.Load(); pc != nil {
|
||||
(*pc).Offer(fullPacket, true)
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
||||
ipv6Count++
|
||||
}
|
||||
|
||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
||||
// routing-correctness checks above are the real assertions; the counts
|
||||
// are a sanity bound to catch a totally silent path.
|
||||
minDelivered := packetsPerFamily * 80 / 100
|
||||
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
||||
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
||||
}
|
||||
|
||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package device
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
@@ -28,11 +29,20 @@ type PacketFilter interface {
|
||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
}
|
||||
|
||||
// PacketCapture captures raw packets for debugging. Implementations must be
|
||||
// safe for concurrent use and must not block.
|
||||
type PacketCapture interface {
|
||||
// Offer submits a packet for capture. outbound is true for packets
|
||||
// leaving the host (Read path), false for packets arriving (Write path).
|
||||
Offer(data []byte, outbound bool)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
type FilteredDevice struct {
|
||||
tun.Device
|
||||
|
||||
filter PacketFilter
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
@@ -63,20 +73,25 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
filter := d.filter
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if filter == nil {
|
||||
return
|
||||
if filter != nil {
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
i--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
i--
|
||||
if pc := d.capture.Load(); pc != nil {
|
||||
for i := 0; i < n; i++ {
|
||||
(*pc).Offer(bufs[i][offset:offset+sizes[i]], true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,6 +100,13 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
|
||||
// Write wraps write method with filtering feature
|
||||
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
// Capture before filtering so dropped packets are still visible in captures.
|
||||
if pc := d.capture.Load(); pc != nil {
|
||||
for _, buf := range bufs {
|
||||
(*pc).Offer(buf[offset:], false)
|
||||
}
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
filter := d.filter
|
||||
d.mutex.RUnlock()
|
||||
@@ -96,9 +118,10 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
dropped := 0
|
||||
for _, buf := range bufs {
|
||||
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
if filter.FilterInbound(buf[offset:], len(buf)) {
|
||||
dropped++
|
||||
} else {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,3 +136,14 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
|
||||
d.filter = filter
|
||||
d.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
|
||||
// Uses atomic store so the hot path (Read/Write) is a single pointer load
|
||||
// with no locking overhead when capture is off.
|
||||
func (d *FilteredDevice) SetCapture(pc PacketCapture) {
|
||||
if pc == nil {
|
||||
d.capture.Store(nil)
|
||||
return
|
||||
}
|
||||
d.capture.Store(&pc)
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if n != 0 {
|
||||
if n != 1 {
|
||||
t.Errorf("expected n=1, got %d", n)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,6 +61,7 @@ allocs.prof: Allocations profiling information.
|
||||
threadcreate.prof: Thread creation profiling information.
|
||||
cpu.prof: CPU profiling information.
|
||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||
capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data.
|
||||
|
||||
|
||||
Anonymization Process
|
||||
@@ -234,6 +235,7 @@ type BundleGenerator struct {
|
||||
logPath string
|
||||
tempDir string
|
||||
cpuProfile []byte
|
||||
capturePath string
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
clientMetrics MetricsExporter
|
||||
|
||||
@@ -257,7 +259,8 @@ type GeneratorDependencies struct {
|
||||
LogPath string
|
||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
CapturePath string
|
||||
RefreshStatus func()
|
||||
ClientMetrics MetricsExporter
|
||||
}
|
||||
|
||||
@@ -277,6 +280,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
logPath: deps.LogPath,
|
||||
tempDir: deps.TempDir,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
capturePath: deps.CapturePath,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
clientMetrics: deps.ClientMetrics,
|
||||
|
||||
@@ -346,6 +350,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addCaptureFile(); err != nil {
|
||||
log.Errorf("failed to add capture file to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addStackTrace(); err != nil {
|
||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||
}
|
||||
@@ -669,6 +677,29 @@ func (g *BundleGenerator) addCPUProfile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCaptureFile() error {
|
||||
if g.capturePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if g.anonymize {
|
||||
log.Info("skipping capture file in anonymized bundle (contains raw packet data)")
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(g.capturePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open capture file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := g.addFileToZip(f, "capture.pcap"); err != nil {
|
||||
return fmt.Errorf("add capture file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStackTrace() error {
|
||||
buf := make([]byte, 5242880) // 5 MB buffer
|
||||
n := runtime.Stack(buf, true)
|
||||
|
||||
@@ -3,10 +3,12 @@ package debug
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
|
||||
t.Skip("Skipping upload test on docker ci")
|
||||
}
|
||||
testDir := t.TempDir()
|
||||
testURL := "http://localhost:8080"
|
||||
addr := reserveLoopbackPort(t)
|
||||
testURL := "http://" + addr
|
||||
t.Setenv("SERVER_URL", testURL)
|
||||
t.Setenv("SERVER_ADDRESS", addr)
|
||||
t.Setenv("STORE_DIR", testDir)
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
|
||||
t.Errorf("Failed to stop server: %v", err)
|
||||
}
|
||||
})
|
||||
waitForServer(t, addr)
|
||||
|
||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||
fileContent := []byte("test file content")
|
||||
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fileContent, createdFileContent)
|
||||
}
|
||||
|
||||
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
||||
// address, then releases it so the server under test can rebind. The close/
|
||||
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
||||
// it's essentially never contended in practice.
|
||||
func reserveLoopbackPort(t *testing.T) string {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := l.Addr().String()
|
||||
require.NoError(t, l.Close())
|
||||
return addr
|
||||
}
|
||||
|
||||
func waitForServer(t *testing.T, addr string) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("server did not start listening on %s in time", addr)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
nsswitchConfPath = "/etc/nsswitch.conf"
|
||||
)
|
||||
|
||||
type resolvConf struct {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() {
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
c.dispatch(w, r, math.MaxInt)
|
||||
}
|
||||
|
||||
// dispatch routes a DNS request through the chain, skipping handlers with
|
||||
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
||||
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
if entry.Priority > maxPriority {
|
||||
continue
|
||||
}
|
||||
if !c.isHandlerMatch(qname, entry) {
|
||||
continue
|
||||
}
|
||||
@@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
||||
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
||||
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
||||
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
||||
// (bounded by the invoked handler's internal timeout).
|
||||
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
if len(r.Question) == 0 {
|
||||
return nil, fmt.Errorf("empty question")
|
||||
}
|
||||
|
||||
base := &internalResponseWriter{}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.dispatch(base, r, maxPriority)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
||||
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
||||
strings.ToLower(r.Question[0].Name), maxPriority)
|
||||
}
|
||||
return base.response, nil
|
||||
}
|
||||
|
||||
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
||||
// priority ≤ maxPriority.
|
||||
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
for _, h := range c.handlers {
|
||||
if h.Pattern == "." && h.Priority <= maxPriority {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
@@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
||||
type internalResponseWriter struct {
|
||||
response *dns.Msg
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
||||
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
|
||||
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
||||
// still surface their answer to ResolveInternal.
|
||||
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.response = msg
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) Close() error { return nil }
|
||||
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
||||
|
||||
// TsigTimersOnly is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
||||
// no-op: in-process queries carry no TSIG state.
|
||||
}
|
||||
|
||||
// Hijack is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) Hijack() {
|
||||
// no-op: in-process queries have no underlying connection to hand off.
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
||||
// which handler ResolveInternal dispatches to.
|
||||
type answeringHandler struct {
|
||||
name string
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *answeringHandler) String() string { return h.name }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
||||
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, 1, len(resp.Answer))
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.Error(t, err, "no handler at or below maxPriority should error")
|
||||
}
|
||||
|
||||
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
||||
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
||||
type rawWriteHandler struct {
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
packed, err := resp.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(packed)
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
||||
type hangingHandler struct {
|
||||
block chan struct{}
|
||||
}
|
||||
|
||||
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
<-h.block
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *hangingHandler) String() string { return "hangingHandler" }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &hangingHandler{block: make(chan struct{})}
|
||||
defer close(h.block)
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
||||
}
|
||||
|
||||
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
||||
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
||||
|
||||
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
|
||||
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
||||
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
||||
|
||||
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
||||
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
}
|
||||
|
||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
osManager, reason, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||
if err != nil {
|
||||
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
||||
}
|
||||
}
|
||||
|
||||
func getOSDNSManagerType() (osManagerType, error) {
|
||||
func getOSDNSManagerType() (osManagerType, string, error) {
|
||||
resolved := isSystemdResolvedRunning()
|
||||
nss := isLibnssResolveUsed()
|
||||
stub := checkStub()
|
||||
|
||||
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
||||
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
||||
// that go through nss-resolve, and in foreign mode they can loop back
|
||||
// through resolved as an upstream.
|
||||
if resolved && (nss || stub) {
|
||||
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
||||
}
|
||||
|
||||
mgr, reason, rejected, err := scanResolvConfHeader()
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
if reason != "" {
|
||||
return mgr, reason, nil
|
||||
}
|
||||
|
||||
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
||||
if len(rejected) > 0 {
|
||||
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
||||
}
|
||||
return fileManager, fallback, nil
|
||||
}
|
||||
|
||||
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
||||
// matching manager. If reason is empty the caller should pick file mode and
|
||||
// use rejected for diagnostics.
|
||||
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||
if cerr := file.Close(); cerr != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
var rejected []string
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
continue
|
||||
}
|
||||
if text[0] != '#' {
|
||||
return fileManager, nil
|
||||
break
|
||||
}
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, nil
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
return fileManager, nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
||||
return mgr, reason, nil, nil
|
||||
} else if rej != "" {
|
||||
rejected = append(rejected, rej)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return 0, fmt.Errorf("scan: %w", err)
|
||||
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
|
||||
return fileManager, nil
|
||||
return 0, "", rejected, nil
|
||||
}
|
||||
|
||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||
// matchResolvConfHeader inspects a single comment line. Returns either a
|
||||
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
||||
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") {
|
||||
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, "NetworkManager header + supported version on dbus", ""
|
||||
}
|
||||
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
||||
}
|
||||
return resolvConfManager, "resolvconf header detected", ""
|
||||
}
|
||||
return 0, "", ""
|
||||
}
|
||||
|
||||
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
||||
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
||||
// into file mode while resolved is active.
|
||||
func checkStub() bool {
|
||||
rConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -139,3 +178,36 @@ func checkStub() bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
||||
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
||||
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
||||
func isLibnssResolveUsed() bool {
|
||||
bs, err := os.ReadFile(nsswitchConfPath)
|
||||
if err != nil {
|
||||
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
||||
return false
|
||||
}
|
||||
return parseNsswitchResolveAhead(bs)
|
||||
}
|
||||
|
||||
func parseNsswitchResolveAhead(data []byte) bool {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||
line = line[:i]
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 || fields[0] != "hosts:" {
|
||||
continue
|
||||
}
|
||||
for _, module := range fields[1:] {
|
||||
switch module {
|
||||
case "dns":
|
||||
return false
|
||||
case "resolve":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
76
client/internal/dns/host_unix_test.go
Normal file
76
client/internal/dns/host_unix_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseNsswitchResolveAhead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "resolve before dns with action token",
|
||||
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "dns before resolve",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "debian default with only dns",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "neither resolve nor dns",
|
||||
in: "hosts: files myhostname\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no hosts line",
|
||||
in: "passwd: files systemd\ngroup: files systemd\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
in: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "comments and blank lines ignored",
|
||||
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "trailing inline comment",
|
||||
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "hosts token must be the first field",
|
||||
in: " hosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other db line mentioning resolve is ignored",
|
||||
in: "networks: resolve\nhosts: dns\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "only resolve, no dns",
|
||||
in: "hosts: files resolve\n",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
||||
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,40 +2,83 @@ package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const dnsTimeout = 5 * time.Second
|
||||
const (
|
||||
dnsTimeout = 5 * time.Second
|
||||
defaultTTL = 300 * time.Second
|
||||
refreshBackoff = 30 * time.Second
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains
|
||||
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
||||
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
||||
)
|
||||
|
||||
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
||||
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
||||
// system resolver.
|
||||
type ChainResolver interface {
|
||||
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
||||
HasRootHandlerAtOrBelow(maxPriority int) bool
|
||||
}
|
||||
|
||||
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
||||
// records and cachedAt are set at construction and treated as immutable;
|
||||
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
||||
// Resolver.mutex.
|
||||
type cachedRecord struct {
|
||||
records []dns.RR
|
||||
cachedAt time.Time
|
||||
lastFailedRefresh time.Time
|
||||
consecFailures int
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question][]dns.RR
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type ipsResponse struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
|
||||
// refreshing tracks questions whose refresh is running via the OS
|
||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
||||
// the OS resolver routed the recursive query back to us (loop). Only
|
||||
// the OS path arms this so chain-path refreshes don't produce false
|
||||
// positives. The atomic bool is CAS-flipped once per refresh to
|
||||
// throttle the warning log.
|
||||
refreshing map[dns.Question]*atomic.Bool
|
||||
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +87,19 @@ func (m *Resolver) String() string {
|
||||
return "MgmtCacheResolver"
|
||||
}
|
||||
|
||||
// ServeDNS implements dns.Handler interface.
|
||||
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
||||
// maxPriority caps which handlers may answer refresh queries (typically
|
||||
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
||||
// mgmt/route/local handlers are skipped).
|
||||
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
||||
m.mutex.Lock()
|
||||
m.chain = chain
|
||||
m.chainMaxPriority = maxPriority
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
||||
// immediately and refreshed asynchronously (stale-while-revalidate).
|
||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
m.continueToNext(w, r)
|
||||
@@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
records, found := m.records[question]
|
||||
cached, found := m.records[question]
|
||||
inflight := m.refreshing[question]
|
||||
var shouldRefresh bool
|
||||
if found {
|
||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
||||
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
||||
shouldRefresh = stale && !inBackoff
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
@@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
||||
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
||||
question.Name)
|
||||
}
|
||||
|
||||
// Skip scheduling a refresh goroutine if one is already inflight for
|
||||
// this question; singleflight would dedup anyway but skipping avoids
|
||||
// a parked goroutine per stale hit under bursty load.
|
||||
if shouldRefresh && inflight == nil {
|
||||
m.scheduleRefresh(question, cached)
|
||||
}
|
||||
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Authoritative = false
|
||||
resp.RecursionAvailable = true
|
||||
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
||||
|
||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||
|
||||
@@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// AddDomain manually adds a domain to cache by resolving it.
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||
if err != nil {
|
||||
return err
|
||||
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
||||
|
||||
if errA != nil && errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
var aRecords, aaaaRecords []dns.RR
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
rr := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}
|
||||
aRecords = append(aRecords, rr)
|
||||
} else if ip.Is6() {
|
||||
rr := &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}
|
||||
aaaaRecords = append(aaaaRecords, rr)
|
||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||
if err := errors.Join(errA, errAAAA); err != nil {
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
||||
}
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if len(aRecords) > 0 {
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aQuestion] = aRecords
|
||||
}
|
||||
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
||||
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
||||
|
||||
if len(aaaaRecords) > 0 {
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aaaaQuestion] = aaaaRecords
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||
resultChan := make(chan *ipsResponse, 1)
|
||||
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
||||
// untouched on error. Caller holds m.mutex.
|
||||
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
switch {
|
||||
case len(records) > 0:
|
||||
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
||||
case err == nil:
|
||||
delete(m.records, q)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||
resultChan <- &ipsResponse{
|
||||
err: err,
|
||||
ips: ips,
|
||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
||||
return nil, m.refreshQuestion(question, expected)
|
||||
})
|
||||
}
|
||||
|
||||
// refreshQuestion replaces the cached records on success, or marks the entry
|
||||
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
||||
// a resolver loop by spotting a query for this same question arriving on us.
|
||||
// expected pins the cache entry observed at schedule time; mutations only apply
|
||||
// if m.records[question] still points at it.
|
||||
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
||||
if err != nil {
|
||||
m.markRefreshFailed(question, expected)
|
||||
return fmt.Errorf("parse domain: %w", err)
|
||||
}
|
||||
|
||||
records, err := m.lookupRecords(ctx, d, question)
|
||||
if err != nil {
|
||||
fails := m.markRefreshFailed(question, expected)
|
||||
logf := log.Warnf
|
||||
if fails == 0 || fails > 1 {
|
||||
logf = log.Debugf
|
||||
}
|
||||
}()
|
||||
|
||||
var resp *ipsResponse
|
||||
|
||||
select {
|
||||
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resp = <-resultChan:
|
||||
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.err != nil {
|
||||
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
||||
if len(records) == 0 {
|
||||
m.mutex.Lock()
|
||||
if m.records[question] == expected {
|
||||
delete(m.records, question)
|
||||
m.mutex.Unlock()
|
||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
return resp.ips, nil
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
if m.records[question] != expected {
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Resolver) markRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
m.refreshing[question] = &atomic.Bool{}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
delete(m.refreshing, question)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
||||
// count so callers can downgrade subsequent failure logs to debug.
|
||||
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
c, ok := m.records[question]
|
||||
if !ok || c != expected {
|
||||
return 0
|
||||
}
|
||||
c.lastFailedRefresh = time.Now()
|
||||
c.consecFailures++
|
||||
return c.consecFailures
|
||||
}
|
||||
|
||||
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
||||
// callers tell records, NODATA (nil err, no records), and failure apart.
|
||||
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
||||
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
||||
// not the system resolver, so net.DefaultResolver will not loop back.
|
||||
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
||||
// arms the loop detector for the duration of its call so that ServeDNS can
|
||||
// spot the OS resolver routing the recursive query back to us.
|
||||
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver.
|
||||
m.markRefreshing(q)
|
||||
defer m.clearRefreshing(q)
|
||||
|
||||
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
||||
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
||||
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
||||
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
msg := &dns.Msg{}
|
||||
msg.SetQuestion(dnsName, qtype)
|
||||
msg.RecursionDesired = true
|
||||
|
||||
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chain resolve: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("chain resolve returned nil response")
|
||||
}
|
||||
if resp.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
||||
}
|
||||
|
||||
ttl := uint32(m.cacheTTL.Seconds())
|
||||
owners := cnameOwners(dnsName, resp.Answer)
|
||||
var filtered []dns.RR
|
||||
for _, rr := range resp.Answer {
|
||||
h := rr.Header()
|
||||
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
||||
continue
|
||||
}
|
||||
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
||||
continue
|
||||
}
|
||||
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
||||
filtered = append(filtered, cp)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
||||
// returns (nil, nil).
|
||||
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
network := resutil.NetworkForQtype(qtype)
|
||||
if network == "" {
|
||||
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
||||
}
|
||||
|
||||
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
|
||||
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
||||
if result.Rcode == dns.RcodeSuccess {
|
||||
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
||||
}
|
||||
|
||||
if result.Err != nil {
|
||||
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
||||
}
|
||||
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
||||
}
|
||||
|
||||
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
||||
// so downstream resolvers don't cache an answer for longer than we will.
|
||||
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
||||
remaining := m.cacheTTL - time.Since(cachedAt)
|
||||
if remaining <= 0 {
|
||||
return 0
|
||||
}
|
||||
return uint32((remaining + time.Second - 1) / time.Second)
|
||||
}
|
||||
|
||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||
@@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aQuestion)
|
||||
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aaaaQuestion)
|
||||
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
||||
delete(m.records, qA)
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
||||
// A/AAAA records return nil.
|
||||
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.A = slices.Clone(r.A)
|
||||
return &cp
|
||||
case *dns.AAAA:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.AAAA = slices.Clone(r.AAAA)
|
||||
return &cp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
||||
// stamping ttl so the response shares no memory with the cached slice.
|
||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
||||
out := make([]dns.RR, 0, len(records))
|
||||
for _, rr := range records {
|
||||
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
||||
out = append(out, cp)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
||||
// in answer, iterating until fixed point so out-of-order chains resolve.
|
||||
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
||||
owners := map[string]bool{dnsName: true}
|
||||
for {
|
||||
added := false
|
||||
for _, rr := range answer {
|
||||
cname, ok := rr.(*dns.CNAME)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
||||
if !owners[name] {
|
||||
continue
|
||||
}
|
||||
target := strings.ToLower(dns.Fqdn(cname.Target))
|
||||
if !owners[target] {
|
||||
owners[target] = true
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if !added {
|
||||
return owners
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
||||
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
||||
func resolveCacheTTL() time.Duration {
|
||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return defaultTTL
|
||||
}
|
||||
|
||||
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
}
|
||||
|
||||
func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.hasRoot
|
||||
}
|
||||
|
||||
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
f.mu.Lock()
|
||||
q := msg.Question[0]
|
||||
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
if onLookup != nil {
|
||||
onLookup()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(msg)
|
||||
resp.Answer = answers
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
key := name + "|" + dns.TypeToString[qtype]
|
||||
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
||||
case dns.TypeAAAA:
|
||||
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
||||
}
|
||||
|
||||
// waitFor polls the predicate until it returns true or the deadline passes.
|
||||
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(d)
|
||||
for time.Now().Before(deadline) {
|
||||
if fn() {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("condition not met within %s", d)
|
||||
}
|
||||
|
||||
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
||||
t.Helper()
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(name, dns.TypeA)
|
||||
w := &test.MockResponseWriter{}
|
||||
r.ServeDNS(w, msg)
|
||||
return w.GetLastResponse()
|
||||
}
|
||||
|
||||
func firstA(t *testing.T, resp *dns.Msg) string {
|
||||
t.Helper()
|
||||
require.NotNil(t, resp)
|
||||
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok, "expected A record")
|
||||
return a.A.String()
|
||||
}
|
||||
|
||||
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
// Same cached entry age, different cacheTTL values: the shorter TTL must
|
||||
// trigger a background refresh, the longer one must not. Proves that the
|
||||
// per-Resolver cacheTTL field actually drives the stale decision.
|
||||
cachedAt := time.Now().Add(-100 * time.Millisecond)
|
||||
|
||||
newRec := func() *cachedRecord {
|
||||
return &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: cachedAt,
|
||||
}
|
||||
}
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
|
||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = 10 * time.Millisecond
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount(q.Name, dns.TypeA) >= 1
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = time.Hour
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(), // fresh
|
||||
}
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
||||
}
|
||||
|
||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
||||
}
|
||||
|
||||
// First query: serves stale immediately.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
||||
})
|
||||
|
||||
// Next query should now return the refreshed IP.
|
||||
waitFor(t, time.Second, func() bool {
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
|
||||
var inflight atomic.Int32
|
||||
var maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
cur := inflight.Add(1)
|
||||
defer inflight.Add(-1)
|
||||
for {
|
||||
prev := maxInflight.Load()
|
||||
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
||||
}
|
||||
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
waitFor(t, 2*time.Second, func() bool {
|
||||
return inflight.Load() == 0
|
||||
})
|
||||
|
||||
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
||||
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
||||
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
||||
}
|
||||
|
||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
// First stale hit triggers a refresh attempt that fails.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
||||
})
|
||||
waitFor(t, time.Second, func() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
c, ok := r.records[q]
|
||||
return ok && !c.lastFailedRefresh.IsZero()
|
||||
})
|
||||
|
||||
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
||||
for i := 0; i < 10; i++ {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
||||
}
|
||||
|
||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.hasRoot = false
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
// With hasRoot=false the chain must not be consulted. Use a short
|
||||
// deadline so the OS fallback returns quickly without waiting on a
|
||||
// real network call in CI.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
||||
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
||||
"chain must not be used when no root handler is registered at the bound priority")
|
||||
}
|
||||
|
||||
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||
// ServeDNS being invoked for a question while a refresh for that question
|
||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Simulate an inflight refresh.
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
require.NotNil(t, inflight)
|
||||
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
||||
}
|
||||
|
||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
||||
// (CompareAndSwap from false -> true returns true only on the first call).
|
||||
for range 5 {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.True(t, inflight.Load())
|
||||
}
|
||||
|
||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
|
||||
r.mutex.RLock()
|
||||
_, ok := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
||||
}
|
||||
|
||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) {
|
||||
assert.False(t, resolver.MatchSubdomains())
|
||||
}
|
||||
|
||||
func TestResolveCacheTTL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want time.Duration
|
||||
}{
|
||||
{"unset falls back to default", "", defaultTTL},
|
||||
{"valid duration", "45s", 45 * time.Second},
|
||||
{"valid minutes", "2m", 2 * time.Minute},
|
||||
{"malformed falls back to default", "not-a-duration", defaultTTL},
|
||||
{"zero falls back to default", "0s", defaultTTL},
|
||||
{"negative falls back to default", "-5s", defaultTTL},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, tc.value)
|
||||
got := resolveCacheTTL()
|
||||
assert.Equal(t, tc.want, got, "parsed TTL should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, "7s")
|
||||
r := NewResolver()
|
||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
||||
}
|
||||
|
||||
func TestResolver_ResponseTTL(t *testing.T) {
|
||||
now := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
cacheTTL time.Duration
|
||||
cachedAt time.Time
|
||||
wantMin uint32
|
||||
wantMax uint32
|
||||
}{
|
||||
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
|
||||
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
|
||||
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
|
||||
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := &Resolver{cacheTTL: tc.cacheTTL}
|
||||
got := r.responseTTL(tc.cachedAt)
|
||||
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
|
||||
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -212,6 +212,7 @@ func newDefaultServer(
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
mgmtCacheResolver := mgmt.NewResolver()
|
||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
|
||||
@@ -26,7 +26,9 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
@@ -67,6 +69,7 @@ import (
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@@ -217,6 +220,8 @@ type Engine struct {
|
||||
portForwardManager *portforward.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
afpacketCapture *capture.AFPacketCapture
|
||||
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
@@ -570,7 +575,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start()
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
@@ -604,6 +609,8 @@ func (e *Engine) createFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
firewalld.SetParentContext(e.ctx)
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil {
|
||||
@@ -1695,6 +1702,11 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
}
|
||||
|
||||
func (e *Engine) close() {
|
||||
if e.afpacketCapture != nil {
|
||||
e.afpacketCapture.Stop()
|
||||
e.afpacketCapture = nil
|
||||
}
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
@@ -2160,6 +2172,62 @@ func (e *Engine) Address() (netip.Addr, error) {
|
||||
return e.wgInterface.Address().IP, nil
|
||||
}
|
||||
|
||||
// SetCapture sets or clears packet capture on the WireGuard device.
|
||||
// On userspace WireGuard, it taps the FilteredDevice directly.
|
||||
// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture.
|
||||
// Pass nil to disable capture.
|
||||
func (e *Engine) SetCapture(pc device.PacketCapture) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
intf := e.wgInterface
|
||||
if intf == nil {
|
||||
return errors.New("wireguard interface not initialized")
|
||||
}
|
||||
|
||||
if e.afpacketCapture != nil {
|
||||
e.afpacketCapture.Stop()
|
||||
e.afpacketCapture = nil
|
||||
}
|
||||
|
||||
dev := intf.GetDevice()
|
||||
if dev != nil {
|
||||
dev.SetCapture(pc)
|
||||
e.setForwarderCapture(pc)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Kernel mode: no FilteredDevice. Use AF_PACKET on Linux.
|
||||
if pc == nil {
|
||||
return nil
|
||||
}
|
||||
sess, ok := pc.(*capture.Session)
|
||||
if !ok {
|
||||
return errors.New("filtered device not available and AF_PACKET requires *capture.Session")
|
||||
}
|
||||
|
||||
afc := capture.NewAFPacketCapture(intf.Name(), sess)
|
||||
if err := afc.Start(); err != nil {
|
||||
return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err)
|
||||
}
|
||||
e.afpacketCapture = afc
|
||||
return nil
|
||||
}
|
||||
|
||||
// setForwarderCapture propagates capture to the USP filter's forwarder endpoint.
|
||||
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
|
||||
func (e *Engine) setForwarderCapture(pc device.PacketCapture) {
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
type forwarderCapturer interface {
|
||||
SetPacketCapture(pc forwarder.PacketCapture)
|
||||
}
|
||||
if fc, ok := e.firewall.(forwarderCapturer); ok {
|
||||
fc.SetPacketCapture(pc)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||
if e.firewall == nil {
|
||||
log.Warn("firewall is disabled, not updating forwarding rules")
|
||||
|
||||
@@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
forceRelay := IsForceRelayed()
|
||||
if !forceRelay {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
if !forceRelay {
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
@@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.wgWatcherCancel()
|
||||
}
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.workerICE.Close()
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.Close()
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
err := conn.wgProxyRelay.CloseConn()
|
||||
@@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||
conn.dumpState.RemoteCandidate()
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||
@@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||
|
||||
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
|
||||
//
|
||||
// The result is a tri-state:
|
||||
// - ConnStatusConnected: all available transports are up
|
||||
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
|
||||
// - ConnStatusDisconnected: no working transport
|
||||
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
||||
defer func() {
|
||||
if !connected {
|
||||
if status == guard.ConnStatusDisconnected {
|
||||
conn.logTraceConnState()
|
||||
}
|
||||
}()
|
||||
|
||||
// For JS platform: only relay connection is supported
|
||||
if runtime.GOOS == "js" {
|
||||
return conn.statusRelay.Get() == worker.StatusConnected
|
||||
iceWorkerCreated := conn.workerICE != nil
|
||||
|
||||
var iceInProgress bool
|
||||
if iceWorkerCreated {
|
||||
iceInProgress = conn.workerICE.InProgress()
|
||||
}
|
||||
|
||||
// For non-JS platforms: check ICE connection status
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
// If relay is supported with peer, it must also be connected
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return evalConnStatus(connStatusInputs{
|
||||
forceRelay: IsForceRelayed(),
|
||||
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
|
||||
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
|
||||
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
|
||||
iceWorkerCreated: iceWorkerCreated,
|
||||
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
|
||||
iceInProgress: iceInProgress,
|
||||
})
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
@@ -926,3 +935,43 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
|
||||
// "Relay up and needed" — the peer uses relay and the transport is connected.
|
||||
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
|
||||
|
||||
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
|
||||
if in.forceRelay {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// Remote peer doesn't support ICE, or we haven't created the worker yet:
|
||||
// relay is the only possible transport.
|
||||
if !in.remoteSupportsICE || !in.iceWorkerCreated {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// ICE counts as "up" when the status is anything other than Disconnected, OR
|
||||
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
|
||||
iceUp := in.iceStatusConnecting || in.iceInProgress
|
||||
|
||||
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
|
||||
relayOK := !in.peerUsesRelay || in.relayConnected
|
||||
|
||||
switch {
|
||||
case iceUp && relayOK:
|
||||
return guard.ConnStatusConnected
|
||||
case relayUsedAndUp:
|
||||
// Relay is up but ICE is down — partially connected.
|
||||
return guard.ConnStatusPartiallyConnected
|
||||
default:
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
}
|
||||
|
||||
func boolToConnStatus(connected bool) guard.ConnStatus {
|
||||
if connected {
|
||||
return guard.ConnStatusConnected
|
||||
}
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
|
||||
@@ -13,6 +13,20 @@ const (
|
||||
StatusConnected
|
||||
)
|
||||
|
||||
// connStatusInputs is the primitive-valued snapshot of the state that drives the
|
||||
// tri-state connection classification. Extracted so the decision logic can be unit-tested
|
||||
// without constructing full Worker/Handshaker objects.
|
||||
type connStatusInputs struct {
|
||||
forceRelay bool // NB_FORCE_RELAY or JS/WASM
|
||||
peerUsesRelay bool // remote peer advertises relay support AND local has relay
|
||||
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
|
||||
remoteSupportsICE bool // remote peer sent ICE credentials
|
||||
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
|
||||
iceStatusConnecting bool // statusICE is anything other than Disconnected
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
201
client/internal/peer/conn_status_eval_test.go
Normal file
201
client/internal/peer/conn_status_eval_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
)
|
||||
|
||||
func TestEvalConnStatus_ForceRelay(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "force relay, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer does NOT use relay - disconnected forever",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: false,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE worker not yet created, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: false,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer does not use relay",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: false,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
|
||||
base := connStatusInputs{
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutator func(*connStatusInputs)
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "ICE connected, relay connected, peer uses relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE connected, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE InProgress only, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up, peer uses relay -> partial",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusPartiallyConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, peer does NOT use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
// relayOK = false (peer uses relay but it's down), iceUp = true
|
||||
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
|
||||
// falls into default: Disconnected.
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up but peer does not use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = true // not actually used since peer doesn't rely on it
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
in := base
|
||||
tc.mutator(&in)
|
||||
if got := evalConnStatus(in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
func IsForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -8,7 +8,19 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type isConnectedFunc func() bool
|
||||
// ConnStatus represents the connection state as seen by the guard.
|
||||
type ConnStatus int
|
||||
|
||||
const (
|
||||
// ConnStatusDisconnected means neither ICE nor Relay is connected.
|
||||
ConnStatusDisconnected ConnStatus = iota
|
||||
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
|
||||
ConnStatusPartiallyConnected
|
||||
// ConnStatusConnected means all required connections are established.
|
||||
ConnStatusConnected
|
||||
)
|
||||
|
||||
type connStatusFunc func() ConnStatus
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
// It will trigger to send an offer to the peer then has connection issues.
|
||||
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool
|
||||
// - ICE candidate changes
|
||||
type Guard struct {
|
||||
log *log.Entry
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
isConnectedOnAllWay connStatusFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
relayedConnDisconnected chan struct{}
|
||||
iCEConnDisconnected chan struct{}
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
log: log,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
@@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically check the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
|
||||
//
|
||||
// Behavior depends on the connection state reported by isConnectedOnAllWay:
|
||||
// - Connected: no action, the peer is fully reachable.
|
||||
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
|
||||
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
|
||||
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
|
||||
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
|
||||
//
|
||||
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
|
||||
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
@@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
|
||||
tickerChannel := ticker.C
|
||||
|
||||
iceState := &iceRetryState{log: g.log}
|
||||
defer iceState.reset()
|
||||
|
||||
for {
|
||||
select {
|
||||
case t := <-tickerChannel:
|
||||
if t.IsZero() {
|
||||
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||
tickerChannel = make(<-chan time.Time)
|
||||
continue
|
||||
case <-tickerChannel:
|
||||
switch g.isConnectedOnAllWay() {
|
||||
case ConnStatusConnected:
|
||||
// all good, nothing to do
|
||||
case ConnStatusDisconnected:
|
||||
callback()
|
||||
case ConnStatusPartiallyConnected:
|
||||
if iceState.shouldRetry() {
|
||||
callback()
|
||||
} else {
|
||||
iceState.enterHourlyMode()
|
||||
ticker.Stop()
|
||||
tickerChannel = iceState.hourlyC()
|
||||
}
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
callback()
|
||||
}
|
||||
case <-g.relayedConnDisconnected:
|
||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-g.iCEConnDisconnected:
|
||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-srReconnectedChan:
|
||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
@@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
|
||||
61
client/internal/peer/guard/ice_retry_state.go
Normal file
61
client/internal/peer/guard/ice_retry_state.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
|
||||
maxICERetries = 3
|
||||
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
|
||||
iceRetryInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
|
||||
// After maxICERetries attempts it switches to a periodic hourly retry.
|
||||
type iceRetryState struct {
|
||||
log *log.Entry
|
||||
retries int
|
||||
hourly *time.Ticker
|
||||
}
|
||||
|
||||
func (s *iceRetryState) reset() {
|
||||
s.retries = 0
|
||||
if s.hourly != nil {
|
||||
s.hourly.Stop()
|
||||
s.hourly = nil
|
||||
}
|
||||
}
|
||||
|
||||
// shouldRetry reports whether the caller should send another ICE offer on this tick.
|
||||
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
|
||||
// to the hourly ticker via enterHourlyMode + hourlyC.
|
||||
func (s *iceRetryState) shouldRetry() bool {
|
||||
if s.hourly != nil {
|
||||
s.log.Debugf("hourly ICE retry attempt")
|
||||
return true
|
||||
}
|
||||
|
||||
s.retries++
|
||||
if s.retries <= maxICERetries {
|
||||
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
|
||||
func (s *iceRetryState) enterHourlyMode() {
|
||||
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
|
||||
s.hourly = time.NewTicker(iceRetryInterval)
|
||||
}
|
||||
|
||||
func (s *iceRetryState) hourlyC() <-chan time.Time {
|
||||
if s.hourly == nil {
|
||||
return nil
|
||||
}
|
||||
return s.hourly.C
|
||||
}
|
||||
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func newTestRetryState() *iceRetryState {
|
||||
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
|
||||
}
|
||||
|
||||
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 0; i < maxICERetries; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
if s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if s.hourlyC() == nil {
|
||||
t.Fatalf("hourlyC returned nil after enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
|
||||
// Subsequent calls also return true — we keep retrying on each hourly tick.
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
s.enterHourlyMode()
|
||||
|
||||
s.reset()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel after reset")
|
||||
}
|
||||
if s.retries != 0 {
|
||||
t.Fatalf("retries = %d after reset, want 0", s.retries)
|
||||
}
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.reset()
|
||||
s.reset() // second call must not panic or re-stop a nil ticker
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC non-nil after double reset")
|
||||
}
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
|
||||
return srw
|
||||
}
|
||||
|
||||
func (w *SRWatcher) Start() {
|
||||
func (w *SRWatcher) Start(disableICEMonitor bool) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
if !disableICEMonitor {
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
}
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -43,6 +44,10 @@ type OfferAnswer struct {
|
||||
SessionID *ICESessionID
|
||||
}
|
||||
|
||||
func (o *OfferAnswer) hasICECredentials() bool {
|
||||
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
@@ -59,6 +64,10 @@ type Handshaker struct {
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
|
||||
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
|
||||
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
|
||||
remoteICESupported atomic.Bool
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||
@@ -66,7 +75,7 @@ type Handshaker struct {
|
||||
}
|
||||
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||
return &Handshaker{
|
||||
h := &Handshaker{
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
@@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
// assume remote supports ICE until we learn otherwise from received offers
|
||||
h.remoteICESupported.Store(ice != nil)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handshaker) RemoteICESupported() bool {
|
||||
return h.remoteICESupported.Load()
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
@@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
@@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
@@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error {
|
||||
}
|
||||
|
||||
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer := OfferAnswer{
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
SessionID: &sid,
|
||||
}
|
||||
|
||||
if h.ice != nil && h.RemoteICESupported() {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer.IceCredentials = IceCredentials{uFrag, pwd}
|
||||
answer.SessionID = &sid
|
||||
}
|
||||
|
||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
@@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
|
||||
return answer
|
||||
}
|
||||
|
||||
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
|
||||
hasICE := offer.hasICECredentials()
|
||||
prev := h.remoteICESupported.Swap(hasICE)
|
||||
if prev != hasICE {
|
||||
if hasICE {
|
||||
h.log.Infof("remote peer started sending ICE credentials")
|
||||
} else {
|
||||
h.log.Infof("remote peer stopped sending ICE credentials")
|
||||
if h.ice != nil {
|
||||
h.ice.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool {
|
||||
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
||||
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
var sessionIDBytes []byte
|
||||
if offerAnswer.SessionID != nil {
|
||||
var err error
|
||||
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
}
|
||||
msg, err := signal.MarshalCredential(
|
||||
s.wgPrivateKey,
|
||||
|
||||
@@ -5977,6 +5977,288 @@ func (x *ExposeServiceReady) GetPortAutoAssigned() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type StartCaptureRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
TextOutput bool `protobuf:"varint,1,opt,name=text_output,json=textOutput,proto3" json:"text_output,omitempty"`
|
||||
SnapLen uint32 `protobuf:"varint,2,opt,name=snap_len,json=snapLen,proto3" json:"snap_len,omitempty"`
|
||||
Duration *durationpb.Duration `protobuf:"bytes,3,opt,name=duration,proto3" json:"duration,omitempty"`
|
||||
FilterExpr string `protobuf:"bytes,4,opt,name=filter_expr,json=filterExpr,proto3" json:"filter_expr,omitempty"`
|
||||
Verbose bool `protobuf:"varint,5,opt,name=verbose,proto3" json:"verbose,omitempty"`
|
||||
Ascii bool `protobuf:"varint,6,opt,name=ascii,proto3" json:"ascii,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) Reset() {
|
||||
*x = StartCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StartCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StartCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{90}
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetTextOutput() bool {
|
||||
if x != nil {
|
||||
return x.TextOutput
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetSnapLen() uint32 {
|
||||
if x != nil {
|
||||
return x.SnapLen
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetDuration() *durationpb.Duration {
|
||||
if x != nil {
|
||||
return x.Duration
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetFilterExpr() string {
|
||||
if x != nil {
|
||||
return x.FilterExpr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetVerbose() bool {
|
||||
if x != nil {
|
||||
return x.Verbose
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetAscii() bool {
|
||||
if x != nil {
|
||||
return x.Ascii
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type CapturePacket struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *CapturePacket) Reset() {
|
||||
*x = CapturePacket{}
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *CapturePacket) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*CapturePacket) ProtoMessage() {}
|
||||
|
||||
func (x *CapturePacket) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use CapturePacket.ProtoReflect.Descriptor instead.
|
||||
func (*CapturePacket) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{91}
|
||||
}
|
||||
|
||||
func (x *CapturePacket) GetData() []byte {
|
||||
if x != nil {
|
||||
return x.Data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type StartBundleCaptureRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// timeout auto-stops the capture after this duration.
|
||||
// Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum.
|
||||
Timeout *durationpb.Duration `protobuf:"bytes,1,opt,name=timeout,proto3" json:"timeout,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureRequest) Reset() {
|
||||
*x = StartBundleCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StartBundleCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StartBundleCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartBundleCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{92}
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureRequest) GetTimeout() *durationpb.Duration {
|
||||
if x != nil {
|
||||
return x.Timeout
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type StartBundleCaptureResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureResponse) Reset() {
|
||||
*x = StartBundleCaptureResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StartBundleCaptureResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StartBundleCaptureResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StartBundleCaptureResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{93}
|
||||
}
|
||||
|
||||
type StopBundleCaptureRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StopBundleCaptureRequest) Reset() {
|
||||
*x = StopBundleCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[94]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StopBundleCaptureRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StopBundleCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[94]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StopBundleCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StopBundleCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{94}
|
||||
}
|
||||
|
||||
type StopBundleCaptureResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StopBundleCaptureResponse) Reset() {
|
||||
*x = StopBundleCaptureResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[95]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StopBundleCaptureResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StopBundleCaptureResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[95]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StopBundleCaptureResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StopBundleCaptureResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{95}
|
||||
}
|
||||
|
||||
type PortInfo_Range struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
||||
@@ -5987,7 +6269,7 @@ type PortInfo_Range struct {
|
||||
|
||||
func (x *PortInfo_Range) Reset() {
|
||||
*x = PortInfo_Range{}
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
mi := &file_daemon_proto_msgTypes[97]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5999,7 +6281,7 @@ func (x *PortInfo_Range) String() string {
|
||||
func (*PortInfo_Range) ProtoMessage() {}
|
||||
|
||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
mi := &file_daemon_proto_msgTypes[97]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6548,7 +6830,23 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\vservice_url\x18\x02 \x01(\tR\n" +
|
||||
"serviceUrl\x12\x16\n" +
|
||||
"\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" +
|
||||
"\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" +
|
||||
"\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned\"\xd9\x01\n" +
|
||||
"\x13StartCaptureRequest\x12\x1f\n" +
|
||||
"\vtext_output\x18\x01 \x01(\bR\n" +
|
||||
"textOutput\x12\x19\n" +
|
||||
"\bsnap_len\x18\x02 \x01(\rR\asnapLen\x125\n" +
|
||||
"\bduration\x18\x03 \x01(\v2\x19.google.protobuf.DurationR\bduration\x12\x1f\n" +
|
||||
"\vfilter_expr\x18\x04 \x01(\tR\n" +
|
||||
"filterExpr\x12\x18\n" +
|
||||
"\averbose\x18\x05 \x01(\bR\averbose\x12\x14\n" +
|
||||
"\x05ascii\x18\x06 \x01(\bR\x05ascii\"#\n" +
|
||||
"\rCapturePacket\x12\x12\n" +
|
||||
"\x04data\x18\x01 \x01(\fR\x04data\"P\n" +
|
||||
"\x19StartBundleCaptureRequest\x123\n" +
|
||||
"\atimeout\x18\x01 \x01(\v2\x19.google.protobuf.DurationR\atimeout\"\x1c\n" +
|
||||
"\x1aStartBundleCaptureResponse\"\x1a\n" +
|
||||
"\x18StopBundleCaptureRequest\"\x1b\n" +
|
||||
"\x19StopBundleCaptureResponse*b\n" +
|
||||
"\bLogLevel\x12\v\n" +
|
||||
"\aUNKNOWN\x10\x00\x12\t\n" +
|
||||
"\x05PANIC\x10\x01\x12\t\n" +
|
||||
@@ -6566,7 +6864,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"EXPOSE_UDP\x10\x03\x12\x0e\n" +
|
||||
"\n" +
|
||||
"EXPOSE_TLS\x10\x042\xfc\x15\n" +
|
||||
"EXPOSE_TLS\x10\x042\xff\x17\n" +
|
||||
"\rDaemonService\x126\n" +
|
||||
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
||||
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
||||
@@ -6587,7 +6885,10 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"CleanState\x12\x19.daemon.CleanStateRequest\x1a\x1a.daemon.CleanStateResponse\"\x00\x12H\n" +
|
||||
"\vDeleteState\x12\x1a.daemon.DeleteStateRequest\x1a\x1b.daemon.DeleteStateResponse\"\x00\x12u\n" +
|
||||
"\x1aSetSyncResponsePersistence\x12).daemon.SetSyncResponsePersistenceRequest\x1a*.daemon.SetSyncResponsePersistenceResponse\"\x00\x12H\n" +
|
||||
"\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12D\n" +
|
||||
"\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12F\n" +
|
||||
"\fStartCapture\x12\x1b.daemon.StartCaptureRequest\x1a\x15.daemon.CapturePacket\"\x000\x01\x12]\n" +
|
||||
"\x12StartBundleCapture\x12!.daemon.StartBundleCaptureRequest\x1a\".daemon.StartBundleCaptureResponse\"\x00\x12Z\n" +
|
||||
"\x11StopBundleCapture\x12 .daemon.StopBundleCaptureRequest\x1a!.daemon.StopBundleCaptureResponse\"\x00\x12D\n" +
|
||||
"\x0fSubscribeEvents\x12\x18.daemon.SubscribeRequest\x1a\x13.daemon.SystemEvent\"\x000\x01\x12B\n" +
|
||||
"\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00\x12N\n" +
|
||||
"\rSwitchProfile\x12\x1c.daemon.SwitchProfileRequest\x1a\x1d.daemon.SwitchProfileResponse\"\x00\x12B\n" +
|
||||
@@ -6622,7 +6923,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 93)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 99)
|
||||
var file_daemon_proto_goTypes = []any{
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(ExposeProtocol)(0), // 1: daemon.ExposeProtocol
|
||||
@@ -6719,128 +7020,142 @@ var file_daemon_proto_goTypes = []any{
|
||||
(*ExposeServiceRequest)(nil), // 92: daemon.ExposeServiceRequest
|
||||
(*ExposeServiceEvent)(nil), // 93: daemon.ExposeServiceEvent
|
||||
(*ExposeServiceReady)(nil), // 94: daemon.ExposeServiceReady
|
||||
nil, // 95: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 96: daemon.PortInfo.Range
|
||||
nil, // 97: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 98: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 99: google.protobuf.Timestamp
|
||||
(*StartCaptureRequest)(nil), // 95: daemon.StartCaptureRequest
|
||||
(*CapturePacket)(nil), // 96: daemon.CapturePacket
|
||||
(*StartBundleCaptureRequest)(nil), // 97: daemon.StartBundleCaptureRequest
|
||||
(*StartBundleCaptureResponse)(nil), // 98: daemon.StartBundleCaptureResponse
|
||||
(*StopBundleCaptureRequest)(nil), // 99: daemon.StopBundleCaptureRequest
|
||||
(*StopBundleCaptureResponse)(nil), // 100: daemon.StopBundleCaptureResponse
|
||||
nil, // 101: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 102: daemon.PortInfo.Range
|
||||
nil, // 103: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 104: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 105: google.protobuf.Timestamp
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
||||
98, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
99, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
99, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
98, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||
23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
||||
24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||
34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
95, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
96, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
||||
0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||
44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
|
||||
53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
|
||||
55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
99, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
97, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
98, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
|
||||
94, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||
33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
||||
31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
|
||||
38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||
42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
||||
47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
||||
49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||
51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
||||
54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||
57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
||||
59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
||||
61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
||||
63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
|
||||
65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
|
||||
67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
|
||||
69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
|
||||
72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||
74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||
76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||
78, // 64: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
|
||||
80, // 65: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
82, // 66: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
84, // 67: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
86, // 68: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
88, // 69: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
6, // 70: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
||||
90, // 71: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
92, // 72: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||
9, // 73: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
11, // 74: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
13, // 75: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
15, // 76: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
17, // 77: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
19, // 78: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
30, // 79: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
32, // 80: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
32, // 81: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
37, // 82: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
39, // 83: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
41, // 84: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
43, // 85: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
46, // 86: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
48, // 87: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
50, // 88: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
52, // 89: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
56, // 90: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
58, // 91: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
60, // 92: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
62, // 93: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
64, // 94: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
66, // 95: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
68, // 96: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
70, // 97: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
73, // 98: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
75, // 99: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
77, // 100: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
79, // 101: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
|
||||
81, // 102: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
83, // 103: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
85, // 104: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
87, // 105: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
89, // 106: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
7, // 107: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
||||
91, // 108: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
93, // 109: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||
73, // [73:110] is the sub-list for method output_type
|
||||
36, // [36:73] is the sub-list for method input_type
|
||||
36, // [36:36] is the sub-list for extension type_name
|
||||
36, // [36:36] is the sub-list for extension extendee
|
||||
0, // [0:36] is the sub-list for field type_name
|
||||
2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
||||
104, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
105, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
105, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
104, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||
23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
||||
24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||
34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
101, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
102, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
||||
0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||
44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
|
||||
53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
|
||||
55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
105, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
103, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
104, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
|
||||
94, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||
104, // 35: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration
|
||||
104, // 36: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration
|
||||
33, // 37: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
8, // 38: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
10, // 39: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
12, // 40: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
14, // 41: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
16, // 42: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
18, // 43: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
29, // 44: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
||||
31, // 45: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
31, // 46: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
5, // 47: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
|
||||
38, // 48: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
40, // 49: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||
42, // 50: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
45, // 51: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
||||
47, // 52: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
||||
49, // 53: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||
51, // 54: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
||||
54, // 55: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||
95, // 56: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest
|
||||
97, // 57: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest
|
||||
99, // 58: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest
|
||||
57, // 59: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
||||
59, // 60: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
||||
61, // 61: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
||||
63, // 62: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
|
||||
65, // 63: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
|
||||
67, // 64: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
|
||||
69, // 65: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
|
||||
72, // 66: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||
74, // 67: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||
76, // 68: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||
78, // 69: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
|
||||
80, // 70: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
82, // 71: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
84, // 72: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
86, // 73: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
88, // 74: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
6, // 75: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
||||
90, // 76: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
92, // 77: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||
9, // 78: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
11, // 79: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
13, // 80: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
15, // 81: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
17, // 82: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
19, // 83: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
30, // 84: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
32, // 85: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
32, // 86: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
37, // 87: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
39, // 88: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
41, // 89: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
43, // 90: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
46, // 91: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
48, // 92: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
50, // 93: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
52, // 94: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
56, // 95: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
96, // 96: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket
|
||||
98, // 97: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse
|
||||
100, // 98: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse
|
||||
58, // 99: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
60, // 100: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
62, // 101: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
64, // 102: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
66, // 103: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
68, // 104: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
70, // 105: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
73, // 106: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
75, // 107: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
77, // 108: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
79, // 109: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
|
||||
81, // 110: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
83, // 111: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
85, // 112: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
87, // 113: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
89, // 114: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
7, // 115: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
||||
91, // 116: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
93, // 117: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||
78, // [78:118] is the sub-list for method output_type
|
||||
38, // [38:78] is the sub-list for method input_type
|
||||
38, // [38:38] is the sub-list for extension type_name
|
||||
38, // [38:38] is the sub-list for extension extendee
|
||||
0, // [0:38] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_daemon_proto_init() }
|
||||
@@ -6870,7 +7185,7 @@ func file_daemon_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||
NumEnums: 5,
|
||||
NumMessages: 93,
|
||||
NumMessages: 99,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -64,6 +64,17 @@ service DaemonService {
|
||||
|
||||
rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {}
|
||||
|
||||
// StartCapture begins streaming packet capture on the WireGuard interface.
|
||||
// Requires --enable-capture set at service install/reconfigure time.
|
||||
rpc StartCapture(StartCaptureRequest) returns (stream CapturePacket) {}
|
||||
|
||||
// StartBundleCapture begins capturing packets to a server-side temp file
|
||||
// for inclusion in the next debug bundle. Auto-stops after the given timeout.
|
||||
rpc StartBundleCapture(StartBundleCaptureRequest) returns (StartBundleCaptureResponse) {}
|
||||
|
||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
||||
rpc StopBundleCapture(StopBundleCaptureRequest) returns (StopBundleCaptureResponse) {}
|
||||
|
||||
rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {}
|
||||
|
||||
rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {}
|
||||
@@ -848,3 +859,26 @@ message ExposeServiceReady {
|
||||
string domain = 3;
|
||||
bool port_auto_assigned = 4;
|
||||
}
|
||||
|
||||
message StartCaptureRequest {
|
||||
bool text_output = 1;
|
||||
uint32 snap_len = 2;
|
||||
google.protobuf.Duration duration = 3;
|
||||
string filter_expr = 4;
|
||||
bool verbose = 5;
|
||||
bool ascii = 6;
|
||||
}
|
||||
|
||||
message CapturePacket {
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
message StartBundleCaptureRequest {
|
||||
// timeout auto-stops the capture after this duration.
|
||||
// Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum.
|
||||
google.protobuf.Duration timeout = 1;
|
||||
}
|
||||
|
||||
message StartBundleCaptureResponse {}
|
||||
message StopBundleCaptureRequest {}
|
||||
message StopBundleCaptureResponse {}
|
||||
|
||||
@@ -53,6 +53,14 @@ type DaemonServiceClient interface {
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||
SetSyncResponsePersistence(ctx context.Context, in *SetSyncResponsePersistenceRequest, opts ...grpc.CallOption) (*SetSyncResponsePersistenceResponse, error)
|
||||
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
|
||||
// StartCapture begins streaming packet capture on the WireGuard interface.
|
||||
// Requires --enable-capture set at service install/reconfigure time.
|
||||
StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (DaemonService_StartCaptureClient, error)
|
||||
// StartBundleCapture begins capturing packets to a server-side temp file
|
||||
// for inclusion in the next debug bundle. Auto-stops after the given timeout.
|
||||
StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error)
|
||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
||||
StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error)
|
||||
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error)
|
||||
GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error)
|
||||
SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error)
|
||||
@@ -253,8 +261,58 @@ func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRe
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (DaemonService_StartCaptureClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/StartCapture", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &daemonServiceStartCaptureClient{stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type DaemonService_StartCaptureClient interface {
|
||||
Recv() (*CapturePacket, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type daemonServiceStartCaptureClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *daemonServiceStartCaptureClient) Recv() (*CapturePacket, error) {
|
||||
m := new(CapturePacket)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) {
|
||||
out := new(StartBundleCaptureResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartBundleCapture", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) {
|
||||
out := new(StopBundleCaptureResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopBundleCapture", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/SubscribeEvents", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -439,7 +497,7 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[2], "/daemon.DaemonService/ExposeService", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -509,6 +567,14 @@ type DaemonServiceServer interface {
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||
SetSyncResponsePersistence(context.Context, *SetSyncResponsePersistenceRequest) (*SetSyncResponsePersistenceResponse, error)
|
||||
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
|
||||
// StartCapture begins streaming packet capture on the WireGuard interface.
|
||||
// Requires --enable-capture set at service install/reconfigure time.
|
||||
StartCapture(*StartCaptureRequest, DaemonService_StartCaptureServer) error
|
||||
// StartBundleCapture begins capturing packets to a server-side temp file
|
||||
// for inclusion in the next debug bundle. Auto-stops after the given timeout.
|
||||
StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error)
|
||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
||||
StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error)
|
||||
SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error
|
||||
GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error)
|
||||
SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error)
|
||||
@@ -598,6 +664,15 @@ func (UnimplementedDaemonServiceServer) SetSyncResponsePersistence(context.Conte
|
||||
func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) StartCapture(*StartCaptureRequest, DaemonService_StartCaptureServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method StartCapture not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method StartBundleCapture not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method StopBundleCapture not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented")
|
||||
}
|
||||
@@ -992,6 +1067,63 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_StartCapture_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(StartCaptureRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(DaemonServiceServer).StartCapture(m, &daemonServiceStartCaptureServer{stream})
|
||||
}
|
||||
|
||||
type DaemonService_StartCaptureServer interface {
|
||||
Send(*CapturePacket) error
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type daemonServiceStartCaptureServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *daemonServiceStartCaptureServer) Send(m *CapturePacket) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func _DaemonService_StartBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StartBundleCaptureRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).StartBundleCapture(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/StartBundleCapture",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).StartBundleCapture(ctx, req.(*StartBundleCaptureRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_StopBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StopBundleCaptureRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).StopBundleCapture(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/StopBundleCapture",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).StopBundleCapture(ctx, req.(*StopBundleCaptureRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(SubscribeRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
@@ -1419,6 +1551,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "TracePacket",
|
||||
Handler: _DaemonService_TracePacket_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StartBundleCapture",
|
||||
Handler: _DaemonService_StartBundleCapture_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StopBundleCapture",
|
||||
Handler: _DaemonService_StopBundleCapture_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetEvents",
|
||||
Handler: _DaemonService_GetEvents_Handler,
|
||||
@@ -1489,6 +1629,11 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "StartCapture",
|
||||
Handler: _DaemonService_StartCapture_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
{
|
||||
StreamName: "SubscribeEvents",
|
||||
Handler: _DaemonService_SubscribeEvents_Handler,
|
||||
|
||||
365
client/server/capture.go
Normal file
365
client/server/capture.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
const maxBundleCaptureDuration = 10 * time.Minute
|
||||
|
||||
// bundleCapture holds the state of an in-progress capture destined for the
|
||||
// debug bundle. The lifecycle is:
|
||||
//
|
||||
// StartBundleCapture → capture running, writing to temp file
|
||||
// StopBundleCapture → capture stopped, temp file available
|
||||
// DebugBundle → temp file included in zip, then cleaned up
|
||||
type bundleCapture struct {
|
||||
mu sync.Mutex
|
||||
sess *capture.Session
|
||||
file *os.File
|
||||
engine *internal.Engine
|
||||
cancel context.CancelFunc
|
||||
stopped bool
|
||||
}
|
||||
|
||||
// stop halts the capture session and closes the pcap writer. Idempotent.
|
||||
func (bc *bundleCapture) stop() {
|
||||
bc.mu.Lock()
|
||||
defer bc.mu.Unlock()
|
||||
|
||||
if bc.stopped {
|
||||
return
|
||||
}
|
||||
bc.stopped = true
|
||||
|
||||
if bc.cancel != nil {
|
||||
bc.cancel()
|
||||
}
|
||||
if bc.sess != nil {
|
||||
bc.sess.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// path returns the temp file path, or "" if no file exists.
|
||||
func (bc *bundleCapture) path() string {
|
||||
if bc.file == nil {
|
||||
return ""
|
||||
}
|
||||
return bc.file.Name()
|
||||
}
|
||||
|
||||
// cleanup removes the temp file.
|
||||
func (bc *bundleCapture) cleanup() {
|
||||
if bc.file == nil {
|
||||
return
|
||||
}
|
||||
name := bc.file.Name()
|
||||
if err := bc.file.Close(); err != nil {
|
||||
log.Debugf("close bundle capture file: %v", err)
|
||||
}
|
||||
if err := os.Remove(name); err != nil && !os.IsNotExist(err) {
|
||||
log.Debugf("remove bundle capture file: %v", err)
|
||||
}
|
||||
bc.file = nil
|
||||
}
|
||||
|
||||
// StartCapture streams a pcap or text packet capture over gRPC.
|
||||
// Gated by the --enable-capture service flag.
|
||||
func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.DaemonService_StartCaptureServer) error {
|
||||
if !s.captureEnabled {
|
||||
return status.Error(codes.PermissionDenied,
|
||||
"packet capture is disabled; reinstall or reconfigure the service with --enable-capture")
|
||||
}
|
||||
|
||||
if d := req.GetDuration(); d != nil && d.AsDuration() < 0 {
|
||||
return status.Error(codes.InvalidArgument, "duration must not be negative")
|
||||
}
|
||||
|
||||
matcher, err := parseCaptureFilter(req)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
opts := capture.Options{
|
||||
Matcher: matcher,
|
||||
SnapLen: req.GetSnapLen(),
|
||||
Verbose: req.GetVerbose(),
|
||||
ASCII: req.GetAscii(),
|
||||
}
|
||||
if req.GetTextOutput() {
|
||||
opts.TextOutput = pw
|
||||
} else {
|
||||
opts.Output = pw
|
||||
}
|
||||
|
||||
sess, err := capture.NewSession(opts)
|
||||
if err != nil {
|
||||
pw.Close()
|
||||
return status.Errorf(codes.Internal, "create capture session: %v", err)
|
||||
}
|
||||
|
||||
engine, err := s.claimCapture(sess)
|
||||
if err != nil {
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := engine.SetCapture(sess); err != nil {
|
||||
s.releaseCapture(sess)
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
return status.Errorf(codes.Internal, "set capture: %v", err)
|
||||
}
|
||||
|
||||
// Send an empty initial message to signal that the capture was accepted.
|
||||
// The client waits for this before printing the banner, so it must arrive
|
||||
// before any packet data.
|
||||
if err := stream.Send(&proto.CapturePacket{}); err != nil {
|
||||
s.clearCaptureIfOwner(sess, engine)
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
return status.Errorf(codes.Internal, "send initial message: %v", err)
|
||||
}
|
||||
|
||||
ctx := stream.Context()
|
||||
if d := req.GetDuration(); d != nil {
|
||||
if dur := d.AsDuration(); dur > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, dur)
|
||||
defer cancel()
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
s.clearCaptureIfOwner(sess, engine)
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
}()
|
||||
defer pr.Close()
|
||||
|
||||
log.Infof("packet capture started (text=%v, expr=%q)", req.GetTextOutput(), req.GetFilterExpr())
|
||||
defer func() {
|
||||
stats := sess.Stats()
|
||||
log.Infof("packet capture stopped: %d packets, %d bytes, %d dropped",
|
||||
stats.Packets, stats.Bytes, stats.Dropped)
|
||||
}()
|
||||
|
||||
return streamToGRPC(pr, stream)
|
||||
}
|
||||
|
||||
func streamToGRPC(r io.Reader, stream proto.DaemonService_StartCaptureServer) error {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, readErr := r.Read(buf)
|
||||
if n > 0 {
|
||||
if err := stream.Send(&proto.CapturePacket{Data: buf[:n]}); err != nil {
|
||||
log.Debugf("capture stream send: %v", err)
|
||||
return nil //nolint:nilerr // client disconnected
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
return nil //nolint:nilerr // pipe closed, capture stopped normally
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartBundleCapture begins capturing packets to a server-side temp file for
|
||||
// inclusion in the next debug bundle. Not gated by --enable-capture since the
|
||||
// output stays on the server (same trust level as CPU profiling).
|
||||
//
|
||||
// A timeout auto-stops the capture as a safety net if StopBundleCapture is
|
||||
// never called (e.g. CLI crash).
|
||||
func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCaptureRequest) (*proto.StartBundleCaptureResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.stopBundleCaptureLocked()
|
||||
s.cleanupBundleCapture()
|
||||
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
// Not fatal: kernel mode or not connected. Log and return success
|
||||
// so the debug bundle still generates without capture data.
|
||||
log.Warnf("packet capture unavailable, skipping: %v", err)
|
||||
return &proto.StartBundleCaptureResponse{}, nil
|
||||
}
|
||||
|
||||
timeout := req.GetTimeout().AsDuration()
|
||||
if timeout <= 0 || timeout > maxBundleCaptureDuration {
|
||||
timeout = maxBundleCaptureDuration
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "netbird.capture.*.pcap")
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "create temp file: %v", err)
|
||||
}
|
||||
|
||||
sess, err := capture.NewSession(capture.Options{Output: f})
|
||||
if err != nil {
|
||||
f.Close()
|
||||
os.Remove(f.Name())
|
||||
return nil, status.Errorf(codes.Internal, "create capture session: %v", err)
|
||||
}
|
||||
|
||||
if err := engine.SetCapture(sess); err != nil {
|
||||
sess.Stop()
|
||||
f.Close()
|
||||
os.Remove(f.Name())
|
||||
log.Warnf("packet capture unavailable (no filtered device), skipping: %v", err)
|
||||
return &proto.StartBundleCaptureResponse{}, nil
|
||||
}
|
||||
s.activeCapture = sess
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
bc := &bundleCapture{
|
||||
sess: sess,
|
||||
file: f,
|
||||
engine: engine,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
s.bundleCapture = bc
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
s.mutex.Lock()
|
||||
if s.bundleCapture == bc {
|
||||
s.stopBundleCaptureLocked()
|
||||
} else {
|
||||
bc.stop()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
log.Infof("bundle capture auto-stopped after timeout")
|
||||
}()
|
||||
log.Infof("bundle capture started (timeout=%s, file=%s)", timeout, f.Name())
|
||||
|
||||
return &proto.StartBundleCaptureResponse{}, nil
|
||||
}
|
||||
|
||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
||||
func (s *Server) StopBundleCapture(_ context.Context, _ *proto.StopBundleCaptureRequest) (*proto.StopBundleCaptureResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.stopBundleCaptureLocked()
|
||||
return &proto.StopBundleCaptureResponse{}, nil
|
||||
}
|
||||
|
||||
// stopBundleCaptureLocked stops the bundle capture if running. Must hold s.mutex.
|
||||
func (s *Server) stopBundleCaptureLocked() {
|
||||
if s.bundleCapture == nil {
|
||||
return
|
||||
}
|
||||
bc := s.bundleCapture
|
||||
if bc.engine != nil && s.activeCapture == bc.sess {
|
||||
if err := bc.engine.SetCapture(nil); err != nil {
|
||||
log.Debugf("clear bundle capture: %v", err)
|
||||
}
|
||||
s.activeCapture = nil
|
||||
}
|
||||
bc.stop()
|
||||
|
||||
stats := bc.sess.Stats()
|
||||
log.Infof("bundle capture stopped: %d packets, %d bytes, %d dropped",
|
||||
stats.Packets, stats.Bytes, stats.Dropped)
|
||||
}
|
||||
|
||||
// bundleCapturePath returns the temp file path if a capture has been taken,
|
||||
// stops any running capture, and returns "". Called from DebugBundle.
|
||||
// Must hold s.mutex.
|
||||
func (s *Server) bundleCapturePath() string {
|
||||
if s.bundleCapture == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
s.bundleCapture.stop()
|
||||
return s.bundleCapture.path()
|
||||
}
|
||||
|
||||
// cleanupBundleCapture removes the temp file and clears state. Must hold s.mutex.
|
||||
func (s *Server) cleanupBundleCapture() {
|
||||
if s.bundleCapture == nil {
|
||||
return
|
||||
}
|
||||
s.bundleCapture.cleanup()
|
||||
s.bundleCapture = nil
|
||||
}
|
||||
|
||||
// claimCapture reserves the engine's capture slot for sess. Returns
|
||||
// FailedPrecondition if another capture is already active.
|
||||
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.activeCapture = sess
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// releaseCapture clears the active-capture owner if it still matches sess.
|
||||
func (s *Server) releaseCapture(sess *capture.Session) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.activeCapture == sess {
|
||||
s.activeCapture = nil
|
||||
}
|
||||
}
|
||||
|
||||
// clearCaptureIfOwner clears engine's capture slot only if sess still owns it.
|
||||
func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Engine) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.activeCapture != sess {
|
||||
return
|
||||
}
|
||||
if err := engine.SetCapture(nil); err != nil {
|
||||
log.Debugf("clear capture: %v", err)
|
||||
}
|
||||
s.activeCapture = nil
|
||||
}
|
||||
|
||||
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "client not connected")
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "engine not initialized")
|
||||
}
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// parseCaptureFilter returns a Matcher from the request.
|
||||
// Returns nil (match all) when no filter expression is set.
|
||||
func parseCaptureFilter(req *proto.StartCaptureRequest) (capture.Matcher, error) {
|
||||
expr := req.GetFilterExpr()
|
||||
if expr == "" {
|
||||
return nil, nil //nolint:nilnil // nil Matcher means "match all"
|
||||
}
|
||||
return capture.ParseFilter(expr)
|
||||
}
|
||||
@@ -43,7 +43,9 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
}()
|
||||
}
|
||||
|
||||
// Prepare refresh callback for health probes
|
||||
capturePath := s.bundleCapturePath()
|
||||
defer s.cleanupBundleCapture()
|
||||
|
||||
var refreshStatus func()
|
||||
if s.connectClient != nil {
|
||||
engine := s.connectClient.Engine()
|
||||
@@ -62,6 +64,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: s.logFile,
|
||||
CPUProfile: cpuProfileData,
|
||||
CapturePath: capturePath,
|
||||
RefreshStatus: refreshStatus,
|
||||
ClientMetrics: clientMetrics,
|
||||
},
|
||||
|
||||
@@ -33,6 +33,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -89,7 +90,11 @@ type Server struct {
|
||||
profileManager *profilemanager.ServiceManager
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
networksDisabled bool
|
||||
captureEnabled bool
|
||||
bundleCapture *bundleCapture
|
||||
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
|
||||
activeCapture *capture.Session
|
||||
networksDisabled bool
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
@@ -106,7 +111,7 @@ type oauthAuthFlow struct {
|
||||
}
|
||||
|
||||
// New server instance constructor.
|
||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server {
|
||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool, networksDisabled bool) *Server {
|
||||
s := &Server{
|
||||
rootCtx: ctx,
|
||||
logFile: logFile,
|
||||
@@ -115,6 +120,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
profileManager: profilemanager.NewServiceManager(configFile),
|
||||
profilesDisabled: profilesDisabled,
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
captureEnabled: captureEnabled,
|
||||
networksDisabled: networksDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false)
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
@@ -165,7 +165,7 @@ func TestServer_Up(t *testing.T) {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "console", "", false, false, false)
|
||||
s := New(ctx, "console", "", false, false, false, false)
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -235,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "console", "", false, false, false)
|
||||
s := New(ctx, "console", "", false, false, false, false)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
s := New(ctx, "console", "", false, false, false)
|
||||
s := New(ctx, "console", "", false, false, false, false)
|
||||
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"fyne.io/fyne/v2/widget"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -38,6 +39,7 @@ type debugCollectionParams struct {
|
||||
upload bool
|
||||
uploadURL string
|
||||
enablePersistence bool
|
||||
capture bool
|
||||
}
|
||||
|
||||
// UI components for progress tracking
|
||||
@@ -51,25 +53,58 @@ type progressUI struct {
|
||||
func (s *serviceClient) showDebugUI() {
|
||||
w := s.app.NewWindow("NetBird Debug")
|
||||
w.SetOnClosed(s.cancel)
|
||||
|
||||
w.Resize(fyne.NewSize(600, 500))
|
||||
w.SetFixedSize(true)
|
||||
|
||||
anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil)
|
||||
systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil)
|
||||
systemInfoCheck.SetChecked(true)
|
||||
captureCheck := widget.NewCheck("Include packet capture", nil)
|
||||
uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil)
|
||||
uploadCheck.SetChecked(true)
|
||||
|
||||
uploadURLLabel := widget.NewLabel("Debug upload URL:")
|
||||
uploadURLContainer, uploadURL := s.buildUploadSection(uploadCheck)
|
||||
|
||||
debugModeContainer, runForDurationCheck, durationInput, noteLabel := s.buildDurationSection()
|
||||
|
||||
statusLabel := widget.NewLabel("")
|
||||
statusLabel.Hide()
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.Hide()
|
||||
createButton := widget.NewButton("Create Debug Bundle", nil)
|
||||
|
||||
uiControls := []fyne.Disableable{
|
||||
anonymizeCheck, systemInfoCheck, captureCheck,
|
||||
uploadCheck, uploadURL, runForDurationCheck, durationInput, createButton,
|
||||
}
|
||||
|
||||
createButton.OnTapped = s.getCreateHandler(
|
||||
statusLabel, progressBar, uploadCheck, uploadURL,
|
||||
anonymizeCheck, systemInfoCheck, captureCheck,
|
||||
runForDurationCheck, durationInput, uiControls, w,
|
||||
)
|
||||
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
|
||||
widget.NewLabel(""),
|
||||
anonymizeCheck, systemInfoCheck, captureCheck,
|
||||
uploadCheck, uploadURLContainer,
|
||||
widget.NewLabel(""),
|
||||
debugModeContainer, noteLabel,
|
||||
widget.NewLabel(""),
|
||||
statusLabel, progressBar, createButton,
|
||||
)
|
||||
|
||||
w.SetContent(container.NewPadded(content))
|
||||
w.Show()
|
||||
}
|
||||
|
||||
func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Container, *widget.Entry) {
|
||||
uploadURL := widget.NewEntry()
|
||||
uploadURL.SetText(uptypes.DefaultBundleURL)
|
||||
uploadURL.SetPlaceHolder("Enter upload URL")
|
||||
|
||||
uploadURLContainer := container.NewVBox(
|
||||
uploadURLLabel,
|
||||
uploadURL,
|
||||
)
|
||||
uploadURLContainer := container.NewVBox(widget.NewLabel("Debug upload URL:"), uploadURL)
|
||||
|
||||
uploadCheck.OnChanged = func(checked bool) {
|
||||
if checked {
|
||||
@@ -78,13 +113,14 @@ func (s *serviceClient) showDebugUI() {
|
||||
uploadURLContainer.Hide()
|
||||
}
|
||||
}
|
||||
return uploadURLContainer, uploadURL
|
||||
}
|
||||
|
||||
debugModeContainer := container.NewHBox()
|
||||
func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check, *widget.Entry, *widget.Label) {
|
||||
runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil)
|
||||
runForDurationCheck.SetChecked(true)
|
||||
|
||||
forLabel := widget.NewLabel("for")
|
||||
|
||||
durationInput := widget.NewEntry()
|
||||
durationInput.SetText("1")
|
||||
minutesLabel := widget.NewLabel("minute")
|
||||
@@ -108,63 +144,8 @@ func (s *serviceClient) showDebugUI() {
|
||||
}
|
||||
}
|
||||
|
||||
debugModeContainer.Add(runForDurationCheck)
|
||||
debugModeContainer.Add(forLabel)
|
||||
debugModeContainer.Add(durationInput)
|
||||
debugModeContainer.Add(minutesLabel)
|
||||
|
||||
statusLabel := widget.NewLabel("")
|
||||
statusLabel.Hide()
|
||||
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.Hide()
|
||||
|
||||
createButton := widget.NewButton("Create Debug Bundle", nil)
|
||||
|
||||
// UI controls that should be disabled during debug collection
|
||||
uiControls := []fyne.Disableable{
|
||||
anonymizeCheck,
|
||||
systemInfoCheck,
|
||||
uploadCheck,
|
||||
uploadURL,
|
||||
runForDurationCheck,
|
||||
durationInput,
|
||||
createButton,
|
||||
}
|
||||
|
||||
createButton.OnTapped = s.getCreateHandler(
|
||||
statusLabel,
|
||||
progressBar,
|
||||
uploadCheck,
|
||||
uploadURL,
|
||||
anonymizeCheck,
|
||||
systemInfoCheck,
|
||||
runForDurationCheck,
|
||||
durationInput,
|
||||
uiControls,
|
||||
w,
|
||||
)
|
||||
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
|
||||
widget.NewLabel(""),
|
||||
anonymizeCheck,
|
||||
systemInfoCheck,
|
||||
uploadCheck,
|
||||
uploadURLContainer,
|
||||
widget.NewLabel(""),
|
||||
debugModeContainer,
|
||||
noteLabel,
|
||||
widget.NewLabel(""),
|
||||
statusLabel,
|
||||
progressBar,
|
||||
createButton,
|
||||
)
|
||||
|
||||
paddedContent := container.NewPadded(content)
|
||||
w.SetContent(paddedContent)
|
||||
|
||||
w.Show()
|
||||
modeContainer := container.NewHBox(runForDurationCheck, forLabel, durationInput, minutesLabel)
|
||||
return modeContainer, runForDurationCheck, durationInput, noteLabel
|
||||
}
|
||||
|
||||
func validateMinute(s string, minutesLabel *widget.Label) error {
|
||||
@@ -200,6 +181,7 @@ func (s *serviceClient) getCreateHandler(
|
||||
uploadURL *widget.Entry,
|
||||
anonymizeCheck *widget.Check,
|
||||
systemInfoCheck *widget.Check,
|
||||
captureCheck *widget.Check,
|
||||
runForDurationCheck *widget.Check,
|
||||
duration *widget.Entry,
|
||||
uiControls []fyne.Disableable,
|
||||
@@ -222,6 +204,7 @@ func (s *serviceClient) getCreateHandler(
|
||||
params := &debugCollectionParams{
|
||||
anonymize: anonymizeCheck.Checked,
|
||||
systemInfo: systemInfoCheck.Checked,
|
||||
capture: captureCheck.Checked,
|
||||
upload: uploadCheck.Checked,
|
||||
uploadURL: url,
|
||||
enablePersistence: true,
|
||||
@@ -253,10 +236,7 @@ func (s *serviceClient) getCreateHandler(
|
||||
|
||||
statusLabel.SetText("Creating debug bundle...")
|
||||
go s.handleDebugCreation(
|
||||
anonymizeCheck.Checked,
|
||||
systemInfoCheck.Checked,
|
||||
uploadCheck.Checked,
|
||||
url,
|
||||
params,
|
||||
statusLabel,
|
||||
uiControls,
|
||||
w,
|
||||
@@ -371,7 +351,7 @@ func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time
|
||||
func (s *serviceClient) configureServiceForDebug(
|
||||
conn proto.DaemonServiceClient,
|
||||
state *debugInitialState,
|
||||
enablePersistence bool,
|
||||
params *debugCollectionParams,
|
||||
) {
|
||||
if state.wasDown {
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
@@ -397,7 +377,7 @@ func (s *serviceClient) configureServiceForDebug(
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
if enablePersistence {
|
||||
if params.enablePersistence {
|
||||
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
@@ -417,6 +397,26 @@ func (s *serviceClient) configureServiceForDebug(
|
||||
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
|
||||
log.Warnf("failed to start CPU profiling: %v", err)
|
||||
}
|
||||
|
||||
s.startBundleCaptureIfEnabled(conn, params)
|
||||
}
|
||||
|
||||
func (s *serviceClient) startBundleCaptureIfEnabled(conn proto.DaemonServiceClient, params *debugCollectionParams) {
|
||||
if !params.capture {
|
||||
return
|
||||
}
|
||||
|
||||
const maxCapture = 10 * time.Minute
|
||||
timeout := params.duration + 30*time.Second
|
||||
if timeout > maxCapture {
|
||||
timeout = maxCapture
|
||||
log.Warnf("packet capture clamped to %s (server maximum)", maxCapture)
|
||||
}
|
||||
if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{
|
||||
Timeout: durationpb.New(timeout),
|
||||
}); err != nil {
|
||||
log.Warnf("failed to start bundle capture: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) collectDebugData(
|
||||
@@ -430,7 +430,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var wg sync.WaitGroup
|
||||
startProgressTracker(ctx, &wg, params.duration, progress)
|
||||
|
||||
s.configureServiceForDebug(conn, state, params.enablePersistence)
|
||||
s.configureServiceForDebug(conn, state, params)
|
||||
|
||||
wg.Wait()
|
||||
progress.progressBar.Hide()
|
||||
@@ -440,6 +440,14 @@ func (s *serviceClient) collectDebugData(
|
||||
log.Warnf("failed to stop CPU profiling: %v", err)
|
||||
}
|
||||
|
||||
if params.capture {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
log.Warnf("failed to stop bundle capture: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -520,18 +528,37 @@ func handleError(progress *progressUI, errMsg string) {
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleDebugCreation(
|
||||
anonymize bool,
|
||||
systemInfo bool,
|
||||
upload bool,
|
||||
uploadURL string,
|
||||
params *debugCollectionParams,
|
||||
statusLabel *widget.Label,
|
||||
uiControls []fyne.Disableable,
|
||||
w fyne.Window,
|
||||
) {
|
||||
log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...",
|
||||
anonymize, systemInfo, upload)
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get client for debug: %v", err)
|
||||
statusLabel.SetText(fmt.Sprintf("Error: %v", err))
|
||||
enableUIControls(uiControls)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL)
|
||||
if params.capture {
|
||||
if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{
|
||||
Timeout: durationpb.New(30 * time.Second),
|
||||
}); err != nil {
|
||||
log.Warnf("failed to start bundle capture: %v", err)
|
||||
} else {
|
||||
defer func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
log.Warnf("failed to stop bundle capture: %v", err)
|
||||
}
|
||||
}()
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := s.createDebugBundle(params.anonymize, params.systemInfo, params.uploadURL)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create debug bundle: %v", err)
|
||||
statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err))
|
||||
@@ -543,7 +570,7 @@ func (s *serviceClient) handleDebugCreation(
|
||||
uploadFailureReason := resp.GetUploadFailureReason()
|
||||
uploadedKey := resp.GetUploadedKey()
|
||||
|
||||
if upload {
|
||||
if params.upload {
|
||||
if uploadFailureReason != "" {
|
||||
showUploadFailedDialog(w, localPath, uploadFailureReason)
|
||||
} else {
|
||||
|
||||
@@ -5,6 +5,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
wasmcapture "github.com/netbirdio/netbird/client/wasm/internal/capture"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
@@ -459,6 +461,95 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// createStartCaptureMethod creates the programmable packet capture method.
|
||||
// Returns a JS interface with onpacket callback and stop() method.
|
||||
//
|
||||
// Usage from JavaScript:
|
||||
//
|
||||
// const cap = await client.startCapture({ filter: "tcp port 443", verbose: true })
|
||||
// cap.onpacket = (line) => console.log(line)
|
||||
// const stats = cap.stop()
|
||||
func createStartCaptureMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
var opts js.Value
|
||||
if len(args) > 0 {
|
||||
opts = args[0]
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
iface, err := wasmcapture.Start(client, opts)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err)))
|
||||
return
|
||||
}
|
||||
resolve.Invoke(iface)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// captureMethods returns capture() and stopCapture() that share state for
|
||||
// the console-log shortcut. capture() logs packets to the browser console
|
||||
// and stopCapture() ends it, like Ctrl+C on the CLI.
|
||||
//
|
||||
// Usage from browser devtools console:
|
||||
//
|
||||
// await client.capture() // capture all packets
|
||||
// await client.capture("tcp") // capture with filter
|
||||
// await client.capture({filter: "host 10.0.0.1", verbose: true})
|
||||
// client.stopCapture() // stop and print stats
|
||||
func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) {
|
||||
var mu sync.Mutex
|
||||
var active *wasmcapture.Handle
|
||||
|
||||
startFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
var opts js.Value
|
||||
if len(args) > 0 {
|
||||
opts = args[0]
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if active != nil {
|
||||
active.Stop()
|
||||
active = nil
|
||||
}
|
||||
|
||||
h, err := wasmcapture.StartConsole(client, opts)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err)))
|
||||
return
|
||||
}
|
||||
active = h
|
||||
|
||||
console := js.Global().Get("console")
|
||||
console.Call("log", "[capture] started, call client.stopCapture() to stop")
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
|
||||
stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if active == nil {
|
||||
js.Global().Get("console").Call("log", "[capture] no active capture")
|
||||
return js.Undefined()
|
||||
}
|
||||
|
||||
stats := active.Stop()
|
||||
active = nil
|
||||
|
||||
console := js.Global().Get("console")
|
||||
console.Call("log", fmt.Sprintf("[capture] stopped: %d packets, %d bytes, %d dropped",
|
||||
stats.Packets, stats.Bytes, stats.Dropped))
|
||||
return js.Undefined()
|
||||
})
|
||||
|
||||
return startFn, stopFn
|
||||
}
|
||||
|
||||
// createPromise is a helper to create JavaScript promises
|
||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
@@ -521,6 +612,11 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
|
||||
obj["setLogLevel"] = createSetLogLevelMethod(client)
|
||||
obj["startCapture"] = createStartCaptureMethod(client)
|
||||
|
||||
capStart, capStop := captureMethods(client)
|
||||
obj["capture"] = capStart
|
||||
obj["stopCapture"] = capStop
|
||||
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
176
client/wasm/internal/capture/capture.go
Normal file
176
client/wasm/internal/capture/capture.go
Normal file
@@ -0,0 +1,176 @@
|
||||
//go:build js
|
||||
|
||||
// Package capture bridges the util/capture package to JavaScript via syscall/js.
|
||||
package capture
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
// Handle holds a running capture session so it can be stopped later.
|
||||
type Handle struct {
|
||||
cs *netbird.CaptureSession
|
||||
stopFn js.Func
|
||||
stopped bool
|
||||
}
|
||||
|
||||
// Stop ends the capture and returns stats.
|
||||
func (h *Handle) Stop() netbird.CaptureStats {
|
||||
if h.stopped {
|
||||
return h.cs.Stats()
|
||||
}
|
||||
h.stopped = true
|
||||
h.stopFn.Release()
|
||||
|
||||
h.cs.Stop()
|
||||
return h.cs.Stats()
|
||||
}
|
||||
|
||||
func statsToJS(s netbird.CaptureStats) js.Value {
|
||||
obj := js.Global().Get("Object").Call("create", js.Null())
|
||||
obj.Set("packets", js.ValueOf(s.Packets))
|
||||
obj.Set("bytes", js.ValueOf(s.Bytes))
|
||||
obj.Set("dropped", js.ValueOf(s.Dropped))
|
||||
return obj
|
||||
}
|
||||
|
||||
// parseOpts extracts filter/verbose/ascii from a JS options value.
|
||||
func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) {
|
||||
if jsOpts.IsNull() || jsOpts.IsUndefined() {
|
||||
return
|
||||
}
|
||||
if jsOpts.Type() == js.TypeString {
|
||||
filter = jsOpts.String()
|
||||
return
|
||||
}
|
||||
if jsOpts.Type() != js.TypeObject {
|
||||
return
|
||||
}
|
||||
if f := jsOpts.Get("filter"); !f.IsUndefined() && !f.IsNull() {
|
||||
filter = f.String()
|
||||
}
|
||||
if v := jsOpts.Get("verbose"); !v.IsUndefined() {
|
||||
verbose = v.Truthy()
|
||||
}
|
||||
if a := jsOpts.Get("ascii"); !a.IsUndefined() {
|
||||
ascii = a.Truthy()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Start creates a capture session and returns a JS interface for streaming text
|
||||
// output. The returned object exposes:
|
||||
//
|
||||
// onpacket(callback) - set callback(string) for each text line
|
||||
// stop() - stop capture and return stats { packets, bytes, dropped }
|
||||
//
|
||||
// Options: { filter: string, verbose: bool, ascii: bool } or just a filter string.
|
||||
func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) {
|
||||
filter, verbose, ascii := parseOpts(jsOpts)
|
||||
|
||||
cb := &jsCallbackWriter{}
|
||||
|
||||
cs, err := client.StartCapture(netbird.CaptureOptions{
|
||||
TextOutput: cb,
|
||||
Filter: filter,
|
||||
Verbose: verbose,
|
||||
ASCII: ascii,
|
||||
})
|
||||
if err != nil {
|
||||
return js.Undefined(), err
|
||||
}
|
||||
|
||||
handle := &Handle{cs: cs}
|
||||
|
||||
iface := js.Global().Get("Object").Call("create", js.Null())
|
||||
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
return statsToJS(handle.Stop())
|
||||
})
|
||||
iface.Set("stop", handle.stopFn)
|
||||
iface.Set("onpacket", js.Undefined())
|
||||
cb.setInterface(iface)
|
||||
|
||||
return iface, nil
|
||||
}
|
||||
|
||||
// StartConsole starts a capture that logs every packet line to console.log.
|
||||
// Returns a Handle so the caller can stop it later.
|
||||
func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) {
|
||||
filter, verbose, ascii := parseOpts(jsOpts)
|
||||
|
||||
cb := &jsCallbackWriter{}
|
||||
|
||||
cs, err := client.StartCapture(netbird.CaptureOptions{
|
||||
TextOutput: cb,
|
||||
Filter: filter,
|
||||
Verbose: verbose,
|
||||
ASCII: ascii,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handle := &Handle{cs: cs}
|
||||
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
return statsToJS(handle.Stop())
|
||||
})
|
||||
|
||||
iface := js.Global().Get("Object").Call("create", js.Null())
|
||||
console := js.Global().Get("console")
|
||||
iface.Set("onpacket", console.Get("log").Call("bind", console, js.ValueOf("[capture]")))
|
||||
cb.setInterface(iface)
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
// jsCallbackWriter is an io.Writer that buffers text until a newline, then
|
||||
// invokes the JS onpacket callback with each complete line.
|
||||
type jsCallbackWriter struct {
|
||||
mu sync.Mutex
|
||||
iface js.Value
|
||||
buf strings.Builder
|
||||
}
|
||||
|
||||
func (w *jsCallbackWriter) setInterface(iface js.Value) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.iface = iface
|
||||
}
|
||||
|
||||
func (w *jsCallbackWriter) Write(p []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
w.buf.Write(p)
|
||||
|
||||
var lines []string
|
||||
for {
|
||||
str := w.buf.String()
|
||||
idx := strings.IndexByte(str, '\n')
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
lines = append(lines, str[:idx])
|
||||
w.buf.Reset()
|
||||
if idx+1 < len(str) {
|
||||
w.buf.WriteString(str[idx+1:])
|
||||
}
|
||||
}
|
||||
|
||||
iface := w.iface
|
||||
w.mu.Unlock()
|
||||
|
||||
if iface.IsUndefined() {
|
||||
return len(p), nil
|
||||
}
|
||||
cb := iface.Get("onpacket")
|
||||
if cb.IsUndefined() || cb.IsNull() {
|
||||
return len(p), nil
|
||||
}
|
||||
for _, line := range lines {
|
||||
cb.Invoke(js.ValueOf(line))
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -119,6 +119,8 @@ server:
|
||||
|
||||
# Reverse proxy settings (optional)
|
||||
# reverseProxy:
|
||||
# trustedHTTPProxies: []
|
||||
# trustedHTTPProxiesCount: 0
|
||||
# trustedPeers: []
|
||||
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
|
||||
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
|
||||
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
|
||||
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
|
||||
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.
|
||||
|
||||
@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
|
||||
// which is when Receive has actually returned. Without this, the Receive goroutine
|
||||
// can outlive the test and call t.Logf after teardown, panicking.
|
||||
receiveDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Receive goroutine did not exit after Close")
|
||||
}
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
receivedAfterReconnect := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(receiveDone)
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||
return nil
|
||||
|
||||
2
go.mod
2
go.mod
@@ -323,3 +323,5 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184
|
||||
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
|
||||
|
||||
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
|
||||
|
||||
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -400,8 +400,6 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
|
||||
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||
@@ -449,6 +447,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
|
||||
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
|
||||
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
|
||||
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
|
||||
@@ -472,7 +472,7 @@ start_services_and_show_instructions() {
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
||||
echo "Registering CrowdSec bouncer..."
|
||||
local cs_retries=0
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
|
||||
cs_retries=$((cs_retries + 1))
|
||||
if [[ $cs_retries -ge 30 ]]; then
|
||||
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
return Create(s, func() *middleware.APIRateLimiter {
|
||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||
limiter := middleware.NewAPIRateLimiter(cfg)
|
||||
limiter.SetEnabled(enabled)
|
||||
return limiter
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
return Create(s, func() *grpc.Server {
|
||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||
|
||||
@@ -2311,6 +2311,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
// Verify the already-expired peer is excluded at the store level
|
||||
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
|
||||
}
|
||||
|
||||
// Only the non-expired peer with expiration enabled should be returned
|
||||
require.Len(t, peers, 1)
|
||||
assert.Equal(t, "notexpired01", peers[0].ID)
|
||||
}
|
||||
|
||||
func TestAccount_GetInactivePeers(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
|
||||
@@ -5,9 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
@@ -66,14 +63,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
apiPrefix = "/api"
|
||||
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
|
||||
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
|
||||
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
|
||||
apiPrefix = "/api"
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -94,34 +88,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
rpm := 6
|
||||
if v := os.Getenv(rateLimitingRPMKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
|
||||
burst := 500
|
||||
if v := os.Getenv(rateLimitingBurstKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
|
||||
rateLimitingConfig = &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}
|
||||
if rateLimiter == nil {
|
||||
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
|
||||
rateLimiter = middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
}
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
@@ -129,7 +99,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimitingConfig,
|
||||
rateLimiter,
|
||||
appMetrics.GetMeter(),
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
@@ -42,14 +43,9 @@ func NewAuthMiddleware(
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
var err error
|
||||
@@ -87,17 +83,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
request, err := m.checkPATFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
@@ -106,7 +99,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, request)
|
||||
h.ServeHTTP(w, r)
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
@@ -115,19 +108,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
@@ -143,7 +136,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
if userAuth.AccountId != accountId {
|
||||
@@ -153,7 +146,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||
@@ -164,41 +157,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
if m.patUsageTracker != nil {
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
|
||||
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("invalid Token: %w", err)
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
if time.Now().After(pat.GetExpirationDate()) {
|
||||
return r, fmt.Errorf("token expired")
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
userAuth := auth.UserAuth{
|
||||
@@ -216,7 +209,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isTerraformRequest(r *http.Request) bool {
|
||||
|
||||
@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
disabledLimiter,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
disabledLimiter,
|
||||
nil,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,14 +4,27 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
const (
|
||||
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
|
||||
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
|
||||
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
|
||||
|
||||
defaultAPIRPM = 6
|
||||
defaultAPIBurst = 500
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
|
||||
rpm := defaultAPIRPM
|
||||
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
if rpm <= 0 {
|
||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
|
||||
rpm = defaultAPIRPM
|
||||
}
|
||||
|
||||
burst := defaultAPIBurst
|
||||
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
if burst <= 0 {
|
||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
|
||||
burst = defaultAPIBurst
|
||||
}
|
||||
|
||||
return &RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}, os.Getenv(RateLimitingEnabledEnv) == "true"
|
||||
}
|
||||
|
||||
// limiterEntry holds a rate limiter and its last access time
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
enabled atomic.Bool
|
||||
}
|
||||
|
||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
rl.enabled.Store(true)
|
||||
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
|
||||
rl.enabled.Store(enabled)
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) Enabled() bool {
|
||||
return rl.enabled.Load()
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
|
||||
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
|
||||
return
|
||||
}
|
||||
|
||||
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
|
||||
newBurst := config.Burst
|
||||
|
||||
rl.mu.Lock()
|
||||
rl.config.RequestsPerMinute = config.RequestsPerMinute
|
||||
rl.config.Burst = newBurst
|
||||
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
|
||||
for _, entry := range rl.limiters {
|
||||
snapshot = append(snapshot, entry.limiter)
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
for _, l := range snapshot {
|
||||
l.SetLimit(newRPS)
|
||||
l.SetBurst(newBurst)
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key (token) is allowed
|
||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
if !rl.enabled.Load() {
|
||||
return true
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Allow()
|
||||
}
|
||||
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
// Wait blocks until the rate limiter allows another request for the given key
|
||||
// Returns an error if the context is canceled
|
||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||
if !rl.enabled.Load() {
|
||||
return nil
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Wait(ctx)
|
||||
}
|
||||
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !rl.enabled.Load() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("key"))
|
||||
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
|
||||
|
||||
rl.SetEnabled(false)
|
||||
assert.False(t, rl.Enabled())
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
|
||||
}
|
||||
|
||||
rl.SetEnabled(true)
|
||||
assert.True(t, rl.Enabled())
|
||||
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
|
||||
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
|
||||
// New burst applies to existing keys in place; bucket refills up to new burst over time,
|
||||
// but importantly newly-added keys use the updated config immediately.
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
for i := 0; i < 9; i++ {
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
}
|
||||
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
rl.UpdateConfig(nil) // must not panic or zero the config
|
||||
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"))
|
||||
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
|
||||
rl.Reset("k")
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 600,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("k%d", id)
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.Allow(key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 200; i++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: float64(30 + (i % 90)),
|
||||
Burst: 1 + (i % 20),
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
rl.SetEnabled(i%2 == 0)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRateLimiterConfigFromEnv(t *testing.T) {
|
||||
t.Setenv(RateLimitingEnabledEnv, "true")
|
||||
t.Setenv(RateLimitingRPMEnv, "42")
|
||||
t.Setenv(RateLimitingBurstEnv, "7")
|
||||
|
||||
cfg, enabled := RateLimiterConfigFromEnv()
|
||||
assert.True(t, enabled)
|
||||
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, 7, cfg.Burst)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "false")
|
||||
_, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "")
|
||||
t.Setenv(RateLimitingRPMEnv, "")
|
||||
t.Setenv(RateLimitingBurstEnv, "")
|
||||
cfg, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, defaultAPIBurst, cfg.Burst)
|
||||
|
||||
t.Setenv(RateLimitingRPMEnv, "0")
|
||||
t.Setenv(RateLimitingBurstEnv, "-5")
|
||||
cfg, _ = RateLimiterConfigFromEnv()
|
||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
|
||||
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
// expired peers come separately.
|
||||
if len(networkMap.GetOfflinePeers()) != 1 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer")
|
||||
if len(networkMap.GetOfflinePeers()) != 2 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer")
|
||||
}
|
||||
|
||||
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="
|
||||
|
||||
@@ -1405,6 +1405,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
for _, peer := range peersWithExpiry {
|
||||
if peer.Status.LoginExpired {
|
||||
continue
|
||||
}
|
||||
|
||||
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
if expired {
|
||||
peers = append(peers, peer)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
_ "embed"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
var isUpdate = policy.ID != ""
|
||||
var updateAccountPeers bool
|
||||
var action = activity.PolicyAdded
|
||||
var unchanged bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
||||
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
saveFunc := transaction.CreatePolicy
|
||||
if isUpdate {
|
||||
action = activity.PolicyUpdated
|
||||
saveFunc = transaction.SavePolicy
|
||||
}
|
||||
if policy.Equal(existingPolicy) {
|
||||
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
|
||||
unchanged = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = saveFunc(ctx, policy); err != nil {
|
||||
return err
|
||||
action = activity.PolicyUpdated
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
@@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if unchanged {
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
@@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
@@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||
// validatePolicy validates the policy and its rules. For updates it returns
|
||||
// the existing policy loaded from the store so callers can avoid a second read.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
|
||||
var existingPolicy *types.Policy
|
||||
if policy.ID != "" {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
var err error
|
||||
existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: Refactor to support multiple rules per policy
|
||||
@@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
||||
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, rule := range policy.Rules {
|
||||
@@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||
}
|
||||
|
||||
return nil
|
||||
return existingPolicy, nil
|
||||
}
|
||||
|
||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||
|
||||
@@ -3310,7 +3310,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
result := tx.
|
||||
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||
Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true).
|
||||
Find(&peers, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
|
||||
|
||||
@@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
{
|
||||
name: "should retrieve peers for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 4,
|
||||
expectedCount: 5,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for a non-existing account ID",
|
||||
@@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
name: "should filter peers by partial name",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
nameFilter: "host",
|
||||
expectedCount: 3,
|
||||
expectedCount: 4,
|
||||
},
|
||||
{
|
||||
name: "should filter peers by ip",
|
||||
@@ -2777,14 +2777,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
expectedPeerIDs []string
|
||||
}{
|
||||
{
|
||||
name: "should retrieve peers with expiration for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
name: "should retrieve only non-expired peers with expiration enabled",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
expectedPeerIDs: []string{"notexpired01"},
|
||||
},
|
||||
{
|
||||
name: "should return no peers with expiration for a non-existing account ID",
|
||||
@@ -2803,10 +2805,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
for i, peer := range peers {
|
||||
assert.Equal(t, tt.expectedPeerIDs[i], peer.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
@@ -2887,7 +2909,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) {
|
||||
name: "should retrieve peers for another valid account ID and user ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
expectedCount: 2,
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for existing account ID with empty user ID",
|
||||
|
||||
@@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
})
|
||||
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
// Hold on to req so auth's in-place ctx update is visible after ServeHTTP.
|
||||
req := r.WithContext(ctx)
|
||||
h.ServeHTTP(w, req)
|
||||
close(handlerDone)
|
||||
|
||||
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
|
||||
if err == nil {
|
||||
if userAuth.AccountId != "" {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
|
||||
}
|
||||
if userAuth.UserId != "" {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
|
||||
}
|
||||
}
|
||||
ctx = req.Context()
|
||||
|
||||
if w.Status() > 399 {
|
||||
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
||||
|
||||
@@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
@@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy {
|
||||
return c
|
||||
}
|
||||
|
||||
func (p *Policy) Equal(other *Policy) bool {
|
||||
if p == nil || other == nil {
|
||||
return p == other
|
||||
}
|
||||
|
||||
if p.ID != other.ID ||
|
||||
p.AccountID != other.AccountID ||
|
||||
p.Name != other.Name ||
|
||||
p.Description != other.Description ||
|
||||
p.Enabled != other.Enabled {
|
||||
return false
|
||||
}
|
||||
|
||||
if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(p.Rules) != len(other.Rules) {
|
||||
return false
|
||||
}
|
||||
|
||||
otherRules := make(map[string]*PolicyRule, len(other.Rules))
|
||||
for _, r := range other.Rules {
|
||||
otherRules[r.ID] = r
|
||||
}
|
||||
for _, r := range p.Rules {
|
||||
otherRule, ok := otherRules[r.ID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !r.Equal(otherRule) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to this policy
|
||||
func (p *Policy) EventMeta() map[string]any {
|
||||
return map[string]any{"name": p.Name}
|
||||
|
||||
193
management/server/types/policy_test.go
Normal file
193
management/server/types/policy_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "test",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "test",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentRules(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"443"}},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentRuleCount(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1"},
|
||||
{ID: "r2", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc3", "pc1", "pc2"},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc1", "pc2", "pc3"},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentPostureChecks(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc1", "pc3"},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentScalarFields(t *testing.T) {
|
||||
base := Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "test",
|
||||
Description: "desc",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
other := base
|
||||
other.Name = "changed"
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Enabled = false
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Description = "changed"
|
||||
assert.False(t, base.Equal(&other))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_NilCases(t *testing.T) {
|
||||
var a *Policy
|
||||
var b *Policy
|
||||
assert.True(t, a.Equal(b))
|
||||
|
||||
a = &Policy{ID: "pol1"}
|
||||
assert.False(t, a.Equal(nil))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_RulesMismatchByID(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r2", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_FullScenario(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "Web Access",
|
||||
Description: "Allow web access",
|
||||
Enabled: true,
|
||||
SourcePostureChecks: []string{"pc2", "pc1"},
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "r1",
|
||||
PolicyID: "pol1",
|
||||
Name: "HTTP",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"g2", "g1"},
|
||||
Destinations: []string{"g4", "g3"},
|
||||
Ports: []string{"443", "80", "8080"},
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 8000, End: 9000},
|
||||
{Start: 80, End: 80},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "Web Access",
|
||||
Description: "Allow web access",
|
||||
Enabled: true,
|
||||
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "r1",
|
||||
PolicyID: "pol1",
|
||||
Name: "HTTP",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"g1", "g2"},
|
||||
Destinations: []string{"g3", "g4"},
|
||||
Ports: []string{"80", "8080", "443"},
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
{Start: 8000, End: 9000},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
func (pm *PolicyRule) Equal(other *PolicyRule) bool {
|
||||
if pm == nil || other == nil {
|
||||
return pm == other
|
||||
}
|
||||
|
||||
if pm.ID != other.ID ||
|
||||
pm.PolicyID != other.PolicyID ||
|
||||
pm.Name != other.Name ||
|
||||
pm.Description != other.Description ||
|
||||
pm.Enabled != other.Enabled ||
|
||||
pm.Action != other.Action ||
|
||||
pm.Bidirectional != other.Bidirectional ||
|
||||
pm.Protocol != other.Protocol ||
|
||||
pm.SourceResource != other.SourceResource ||
|
||||
pm.DestinationResource != other.DestinationResource ||
|
||||
pm.AuthorizedUser != other.AuthorizedUser {
|
||||
return false
|
||||
}
|
||||
|
||||
if !stringSlicesEqualUnordered(pm.Sources, other.Sources) {
|
||||
return false
|
||||
}
|
||||
if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) {
|
||||
return false
|
||||
}
|
||||
if !stringSlicesEqualUnordered(pm.Ports, other.Ports) {
|
||||
return false
|
||||
}
|
||||
if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) {
|
||||
return false
|
||||
}
|
||||
if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func stringSlicesEqualUnordered(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
sorted1 := make([]string, len(a))
|
||||
sorted2 := make([]string, len(b))
|
||||
copy(sorted1, a)
|
||||
copy(sorted2, b)
|
||||
slices.Sort(sorted1)
|
||||
slices.Sort(sorted2)
|
||||
return slices.Equal(sorted1, sorted2)
|
||||
}
|
||||
|
||||
func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
cmp := func(x, y RulePortRange) int {
|
||||
if x.Start != y.Start {
|
||||
if x.Start < y.Start {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if x.End != y.End {
|
||||
if x.End < y.End {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
sorted1 := make([]RulePortRange, len(a))
|
||||
sorted2 := make([]RulePortRange, len(b))
|
||||
copy(sorted1, a)
|
||||
copy(sorted2, b)
|
||||
slices.SortFunc(sorted1, cmp)
|
||||
slices.SortFunc(sorted2, cmp)
|
||||
return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool {
|
||||
return x.Start == y.Start && x.End == y.End
|
||||
})
|
||||
}
|
||||
|
||||
func authorizedGroupsEqual(a, b map[string][]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k, va := range a {
|
||||
vb, ok := b[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !stringSlicesEqualUnordered(va, vb) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
194
management/server/types/policyrule_test.go
Normal file
194
management/server/types/policyrule_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"443", "80", "22"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"22", "443", "80"},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentPorts(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"443", "80"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"443", "22"},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g1", "g2", "g3"},
|
||||
Destinations: []string{"g4", "g5"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g3", "g1", "g2"},
|
||||
Destinations: []string{"g5", "g4"},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentSources(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g1", "g2"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g1", "g3"},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 8000, End: 9000},
|
||||
{Start: 80, End: 80},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
{Start: 8000, End: 9000},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 443},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g1": {"u1", "u2", "u3"},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g1": {"u3", "u1", "u2"},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g1": {"u1"},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g2": {"u1"},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) {
|
||||
base := PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Name: "test",
|
||||
Description: "desc",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
}
|
||||
|
||||
other := base
|
||||
other.Name = "changed"
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Enabled = false
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Action = PolicyTrafficActionDrop
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Protocol = PolicyRuleProtocolUDP
|
||||
assert.False(t, base.Equal(&other))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_NilCases(t *testing.T) {
|
||||
var a *PolicyRule
|
||||
var b *PolicyRule
|
||||
assert.True(t, a.Equal(b))
|
||||
|
||||
a = &PolicyRule{ID: "rule1"}
|
||||
assert.False(t, a.Equal(nil))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_EmptySlices(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{},
|
||||
Sources: nil,
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: nil,
|
||||
Sources: []string{},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,12 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -99,6 +104,27 @@ var debugStopCmd = &cobra.Command{
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugCaptureCmd = &cobra.Command{
|
||||
Use: "capture <account-id> [filter expression]",
|
||||
Short: "Capture packets on a client's WireGuard interface",
|
||||
Long: `Captures decrypted packets flowing through a client's WireGuard interface.
|
||||
|
||||
Default output is human-readable text. Use --pcap or --output for pcap binary.
|
||||
Filter arguments after the account ID use BPF-like syntax.
|
||||
|
||||
Examples:
|
||||
netbird-proxy debug capture <account-id>
|
||||
netbird-proxy debug capture <account-id> --duration 1m host 10.0.0.1
|
||||
netbird-proxy debug capture <account-id> host 10.0.0.1 and tcp port 443
|
||||
netbird-proxy debug capture <account-id> not port 22
|
||||
netbird-proxy debug capture <account-id> -o capture.pcap
|
||||
netbird-proxy debug capture <account-id> --pcap | tcpdump -r - -n
|
||||
netbird-proxy debug capture <account-id> --pcap | tshark -r -`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: runDebugCapture,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address")
|
||||
debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format")
|
||||
@@ -110,6 +136,12 @@ func init() {
|
||||
|
||||
debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)")
|
||||
|
||||
debugCaptureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = server default)")
|
||||
debugCaptureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
|
||||
debugCaptureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length (text mode)")
|
||||
debugCaptureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (text mode)")
|
||||
debugCaptureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
|
||||
|
||||
debugCmd.AddCommand(debugHealthCmd)
|
||||
debugCmd.AddCommand(debugClientsCmd)
|
||||
debugCmd.AddCommand(debugStatusCmd)
|
||||
@@ -119,6 +151,7 @@ func init() {
|
||||
debugCmd.AddCommand(debugLogCmd)
|
||||
debugCmd.AddCommand(debugStartCmd)
|
||||
debugCmd.AddCommand(debugStopCmd)
|
||||
debugCmd.AddCommand(debugCaptureCmd)
|
||||
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
}
|
||||
@@ -171,3 +204,84 @@ func runDebugStart(cmd *cobra.Command, args []string) error {
|
||||
func runDebugStop(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||
}
|
||||
|
||||
func runDebugCapture(cmd *cobra.Command, args []string) error {
|
||||
duration, _ := cmd.Flags().GetDuration("duration")
|
||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||
verbose, _ := cmd.Flags().GetBool("verbose")
|
||||
ascii, _ := cmd.Flags().GetBool("ascii")
|
||||
outPath, _ := cmd.Flags().GetString("output")
|
||||
|
||||
// Default to text. Use pcap when --pcap is set or --output is given.
|
||||
wantText := !forcePcap && outPath == ""
|
||||
|
||||
var filterExpr string
|
||||
if len(args) > 1 {
|
||||
filterExpr = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
out, cleanup, err := captureOutputWriter(cmd, outPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if wantText {
|
||||
cmd.PrintErrln("Capturing packets... Press Ctrl+C to stop.")
|
||||
} else {
|
||||
cmd.PrintErrln("Capturing packets (pcap)... Press Ctrl+C to stop.")
|
||||
}
|
||||
|
||||
var durationStr string
|
||||
if duration > 0 {
|
||||
durationStr = duration.String()
|
||||
}
|
||||
|
||||
err = getDebugClient(cmd).Capture(ctx, debug.CaptureOptions{
|
||||
AccountID: args[0],
|
||||
Duration: durationStr,
|
||||
FilterExpr: filterExpr,
|
||||
Text: wantText,
|
||||
Verbose: verbose,
|
||||
ASCII: ascii,
|
||||
Output: out,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PrintErrln("\nCapture finished.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// captureOutputWriter returns the writer and cleanup function for capture output.
|
||||
func captureOutputWriter(cmd *cobra.Command, outPath string) (out *os.File, cleanup func(), err error) {
|
||||
if outPath != "" {
|
||||
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create output file: %w", err)
|
||||
}
|
||||
tmpPath := f.Name()
|
||||
return f, func() {
|
||||
if err := f.Close(); err != nil {
|
||||
cmd.PrintErrf("close output file: %v\n", err)
|
||||
}
|
||||
if fi, err := os.Stat(tmpPath); err == nil && fi.Size() > 0 {
|
||||
if err := os.Rename(tmpPath, outPath); err != nil {
|
||||
cmd.PrintErrf("rename output file: %v\n", err)
|
||||
} else {
|
||||
cmd.PrintErrf("Wrote %s\n", outPath)
|
||||
}
|
||||
} else {
|
||||
os.Remove(tmpPath)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
return os.Stdout, func() {
|
||||
// no cleanup needed for stdout
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -433,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
|
||||
@@ -391,6 +391,15 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
|
||||
}
|
||||
|
||||
func TestSetSessionCookieHasRootPath(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
setSessionCookie(w, "test-token", time.Hour)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
require.Len(t, cookies, 1)
|
||||
assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths")
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
@@ -310,6 +310,76 @@ func (c *Client) printError(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
// CaptureOptions configures a capture request.
|
||||
type CaptureOptions struct {
|
||||
AccountID string
|
||||
Duration string
|
||||
FilterExpr string
|
||||
Text bool
|
||||
Verbose bool
|
||||
ASCII bool
|
||||
Output io.Writer
|
||||
}
|
||||
|
||||
// Capture streams a packet capture from the debug endpoint. The response body
|
||||
// (pcap or text) is written directly to opts.Output until the server closes the
|
||||
// connection or the context is cancelled.
|
||||
func (c *Client) Capture(ctx context.Context, opts CaptureOptions) error {
|
||||
if opts.AccountID == "" {
|
||||
return fmt.Errorf("account ID is required")
|
||||
}
|
||||
if opts.Output == nil {
|
||||
return fmt.Errorf("output writer is required")
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
if opts.Duration != "" {
|
||||
params.Set("duration", opts.Duration)
|
||||
}
|
||||
if opts.FilterExpr != "" {
|
||||
params.Set("filter", opts.FilterExpr)
|
||||
}
|
||||
if opts.Text {
|
||||
params.Set("format", "text")
|
||||
}
|
||||
if opts.Verbose {
|
||||
params.Set("verbose", "true")
|
||||
}
|
||||
if opts.ASCII {
|
||||
params.Set("ascii", "true")
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/capture", url.PathEscape(opts.AccountID))
|
||||
if len(params) > 0 {
|
||||
path += "?" + params.Encode()
|
||||
}
|
||||
|
||||
fullURL := c.baseURL + path
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
// Use a separate client without timeout since captures stream for their full duration.
|
||||
httpClient := &http.Client{}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
_, err = io.Copy(opts.Output, resp.Body)
|
||||
if err != nil && ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
|
||||
data, raw, err := c.fetch(ctx, path)
|
||||
if err != nil {
|
||||
|
||||
@@ -174,6 +174,8 @@ func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, pat
|
||||
h.handleClientStart(w, r, accountID)
|
||||
case "stop":
|
||||
h.handleClientStop(w, r, accountID)
|
||||
case "capture":
|
||||
h.handleCapture(w, r, accountID)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -632,6 +634,81 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
|
||||
})
|
||||
}
|
||||
|
||||
const maxCaptureDuration = 30 * time.Minute
|
||||
|
||||
// handleCapture streams a pcap or text packet capture for the given client.
|
||||
//
|
||||
// Query params:
|
||||
//
|
||||
// duration: capture duration (0 or absent = max, capped at 30m)
|
||||
// format: "text" for human-readable output (default: pcap)
|
||||
// filter: BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443")
|
||||
func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "client not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
duration := maxCaptureDuration
|
||||
if durationStr := r.URL.Query().Get("duration"); durationStr != "" {
|
||||
d, err := time.ParseDuration(durationStr)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid duration: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if d < 0 {
|
||||
http.Error(w, "duration must not be negative", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if d > 0 {
|
||||
duration = min(d, maxCaptureDuration)
|
||||
}
|
||||
}
|
||||
|
||||
filter := r.URL.Query().Get("filter")
|
||||
wantText := r.URL.Query().Get("format") == "text"
|
||||
verbose := r.URL.Query().Get("verbose") == "true"
|
||||
ascii := r.URL.Query().Get("ascii") == "true"
|
||||
|
||||
opts := nbembed.CaptureOptions{Filter: filter, Verbose: verbose, ASCII: ascii}
|
||||
if wantText {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
opts.TextOutput = w
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/vnd.tcpdump.pcap")
|
||||
w.Header().Set("Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=capture-%s.pcap", accountID))
|
||||
opts.Output = w
|
||||
}
|
||||
|
||||
cs, err := client.StartCapture(opts)
|
||||
if err != nil {
|
||||
http.Error(w, "start capture: "+err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer cs.Stop()
|
||||
|
||||
// Flush headers after setup succeeds so errors above can still set status codes.
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
timer := time.NewTimer(duration)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
cs.Stop()
|
||||
|
||||
stats := cs.Stats()
|
||||
h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped",
|
||||
accountID, stats.Packets, stats.Bytes, stats.Dropped)
|
||||
}
|
||||
|
||||
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) {
|
||||
if !wantJSON {
|
||||
http.Redirect(w, r, "/debug", http.StatusSeeOther)
|
||||
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
|
||||
const ConnectTimeout = 10 * time.Second
|
||||
|
||||
const healthCheckTimeout = 5 * time.Second
|
||||
|
||||
const (
|
||||
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
|
||||
// for the management client connection. Value is in bytes.
|
||||
@@ -532,7 +534,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
case connectivity.Ready:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
||||
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
|
||||
|
||||
@@ -23,6 +23,8 @@ import (
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
const healthCheckTimeout = 5 * time.Second
|
||||
|
||||
// ConnStateNotifier is a wrapper interface of the status recorder
|
||||
type ConnStateNotifier interface {
|
||||
MarkSignalDisconnected(error)
|
||||
@@ -263,7 +265,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
case connectivity.Ready:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
||||
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
|
||||
defer cancel()
|
||||
_, err := c.realClient.Send(ctx, &proto.EncryptedMessage{
|
||||
Key: c.key.PublicKey().String(),
|
||||
|
||||
199
util/capture/afpacket_linux.go
Normal file
199
util/capture/afpacket_linux.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// htons converts a uint16 from host to network (big-endian) byte order.
|
||||
func htons(v uint16) uint16 {
|
||||
var buf [2]byte
|
||||
binary.BigEndian.PutUint16(buf[:], v)
|
||||
return binary.NativeEndian.Uint16(buf[:])
|
||||
}
|
||||
|
||||
// AFPacketCapture reads raw packets from a network interface using an
|
||||
// AF_PACKET socket. This is the kernel-mode fallback when FilteredDevice is
|
||||
// not available (kernel WireGuard). Linux only.
|
||||
//
|
||||
// It implements device.PacketCapture so it can be set on a Session, but it
|
||||
// drives its own read loop rather than being called from FilteredDevice.
|
||||
// Call Start to begin and Stop to end.
|
||||
type AFPacketCapture struct {
|
||||
ifaceName string
|
||||
sess *Session
|
||||
fd int
|
||||
mu sync.Mutex
|
||||
stopped chan struct{}
|
||||
started atomic.Bool
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// NewAFPacketCapture creates a capture bound to the given interface.
|
||||
// The session receives packets via Offer.
|
||||
func NewAFPacketCapture(ifaceName string, sess *Session) *AFPacketCapture {
|
||||
return &AFPacketCapture{
|
||||
ifaceName: ifaceName,
|
||||
sess: sess,
|
||||
fd: -1,
|
||||
stopped: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start opens the AF_PACKET socket and begins reading packets.
|
||||
// Packets are fed to the session via Offer. Returns immediately;
|
||||
// the read loop runs in a goroutine.
|
||||
func (c *AFPacketCapture) Start() error {
|
||||
if c.sess == nil {
|
||||
return errors.New("nil capture session")
|
||||
}
|
||||
if !c.started.CompareAndSwap(false, true) {
|
||||
return errors.New("capture already started")
|
||||
}
|
||||
if c.closed.Load() {
|
||||
c.started.Store(false)
|
||||
return errors.New("cannot restart stopped capture")
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByName(c.ifaceName)
|
||||
if err != nil {
|
||||
c.started.Store(false)
|
||||
return fmt.Errorf("interface %s: %w", c.ifaceName, err)
|
||||
}
|
||||
|
||||
fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, int(htons(unix.ETH_P_ALL)))
|
||||
if err != nil {
|
||||
c.started.Store(false)
|
||||
return fmt.Errorf("create AF_PACKET socket: %w", err)
|
||||
}
|
||||
|
||||
addr := &unix.SockaddrLinklayer{
|
||||
Protocol: htons(unix.ETH_P_ALL),
|
||||
Ifindex: iface.Index,
|
||||
}
|
||||
if err := unix.Bind(fd, addr); err != nil {
|
||||
unix.Close(fd)
|
||||
c.started.Store(false)
|
||||
return fmt.Errorf("bind to %s: %w", c.ifaceName, err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.fd = fd
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.readLoop(fd)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the socket and waits for the read loop to exit. Idempotent.
|
||||
func (c *AFPacketCapture) Stop() {
|
||||
if !c.closed.CompareAndSwap(false, true) {
|
||||
if c.started.Load() {
|
||||
<-c.stopped
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
fd := c.fd
|
||||
c.fd = -1
|
||||
c.mu.Unlock()
|
||||
|
||||
if fd >= 0 {
|
||||
unix.Close(fd)
|
||||
}
|
||||
|
||||
if c.started.Load() {
|
||||
<-c.stopped
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AFPacketCapture) readLoop(fd int) {
|
||||
defer close(c.stopped)
|
||||
|
||||
buf := make([]byte, 65536)
|
||||
pollFds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
|
||||
|
||||
for {
|
||||
if c.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := c.pollOnce(pollFds)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
c.recvAndOffer(fd, buf)
|
||||
}
|
||||
}
|
||||
|
||||
// pollOnce waits for data on the fd. Returns true if data is ready, false for timeout/retry.
|
||||
// Returns an error to signal the loop should exit.
|
||||
func (c *AFPacketCapture) pollOnce(pollFds []unix.PollFd) (bool, error) {
|
||||
n, err := unix.Poll(pollFds, 200)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINTR) {
|
||||
return false, nil
|
||||
}
|
||||
if c.closed.Load() {
|
||||
return false, errors.New("closed")
|
||||
}
|
||||
log.Debugf("af_packet poll: %v", err)
|
||||
return false, err
|
||||
}
|
||||
if n == 0 {
|
||||
return false, nil
|
||||
}
|
||||
if pollFds[0].Revents&(unix.POLLERR|unix.POLLHUP|unix.POLLNVAL) != 0 {
|
||||
return false, errors.New("fd error")
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *AFPacketCapture) recvAndOffer(fd int, buf []byte) {
|
||||
nr, from, err := unix.Recvfrom(fd, buf, 0)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
|
||||
return
|
||||
}
|
||||
if !c.closed.Load() {
|
||||
log.Debugf("af_packet recvfrom: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if nr < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
ver := buf[0] >> 4
|
||||
if ver != 4 && ver != 6 {
|
||||
return
|
||||
}
|
||||
|
||||
// The kernel sets Pkttype on AF_PACKET sockets:
|
||||
// PACKET_HOST(0) = addressed to us (inbound)
|
||||
// PACKET_OUTGOING(4) = sent by us (outbound)
|
||||
outbound := false
|
||||
if sa, ok := from.(*unix.SockaddrLinklayer); ok {
|
||||
outbound = sa.Pkttype == unix.PACKET_OUTGOING
|
||||
}
|
||||
c.sess.Offer(buf[:nr], outbound)
|
||||
}
|
||||
|
||||
// Offer satisfies device.PacketCapture but is unused: the AFPacketCapture
|
||||
// drives its own read loop. This exists only so the type signature is
|
||||
// compatible if someone tries to set it as a PacketCapture.
|
||||
func (c *AFPacketCapture) Offer([]byte, bool) {
|
||||
// unused: AFPacketCapture drives its own read loop
|
||||
}
|
||||
26
util/capture/afpacket_stub.go
Normal file
26
util/capture/afpacket_stub.go
Normal file
@@ -0,0 +1,26 @@
|
||||
//go:build !linux
|
||||
|
||||
package capture
|
||||
|
||||
import "errors"
|
||||
|
||||
// AFPacketCapture is not available on this platform.
|
||||
type AFPacketCapture struct{}
|
||||
|
||||
// NewAFPacketCapture returns nil on non-Linux platforms.
|
||||
func NewAFPacketCapture(string, *Session) *AFPacketCapture { return nil }
|
||||
|
||||
// Start returns an error on non-Linux platforms.
|
||||
func (c *AFPacketCapture) Start() error {
|
||||
return errors.New("AF_PACKET capture is only supported on Linux")
|
||||
}
|
||||
|
||||
// Stop is a no-op on non-Linux platforms.
|
||||
func (c *AFPacketCapture) Stop() {
|
||||
// no-op on non-Linux platforms
|
||||
}
|
||||
|
||||
// Offer is a no-op on non-Linux platforms.
|
||||
func (c *AFPacketCapture) Offer([]byte, bool) {
|
||||
// no-op on non-Linux platforms
|
||||
}
|
||||
59
util/capture/capture.go
Normal file
59
util/capture/capture.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Package capture provides userspace packet capture in pcap format.
|
||||
//
|
||||
// It taps decrypted WireGuard packets flowing through the FilteredDevice and
|
||||
// writes them as pcap (readable by tcpdump, tshark, Wireshark) or as
|
||||
// human-readable one-line-per-packet text.
|
||||
package capture
|
||||
|
||||
import "io"
|
||||
|
||||
// Direction indicates whether a packet is entering or leaving the host.
|
||||
type Direction uint8
|
||||
|
||||
const (
|
||||
// Inbound is a packet arriving from the network (FilteredDevice.Write path).
|
||||
Inbound Direction = iota
|
||||
// Outbound is a packet leaving the host (FilteredDevice.Read path).
|
||||
Outbound
|
||||
)
|
||||
|
||||
// String returns "IN" or "OUT".
|
||||
func (d Direction) String() string {
|
||||
if d == Outbound {
|
||||
return "OUT"
|
||||
}
|
||||
return "IN"
|
||||
}
|
||||
|
||||
const (
|
||||
protoICMP = 1
|
||||
protoTCP = 6
|
||||
protoUDP = 17
|
||||
protoICMPv6 = 58
|
||||
)
|
||||
|
||||
// Options configures a capture session.
|
||||
type Options struct {
|
||||
// Output receives pcap-formatted data. Nil disables pcap output.
|
||||
Output io.Writer
|
||||
// TextOutput receives human-readable packet summaries. Nil disables text output.
|
||||
TextOutput io.Writer
|
||||
// Matcher selects which packets to capture. Nil captures all.
|
||||
// Use ParseFilter("host 10.0.0.1 and tcp") or &Filter{...}.
|
||||
Matcher Matcher
|
||||
// Verbose adds seq/ack, TTL, window, total length to text output.
|
||||
Verbose bool
|
||||
// ASCII dumps transport payload as printable ASCII after each packet line.
|
||||
ASCII bool
|
||||
// SnapLen is the maximum bytes captured per packet. 0 means 65535.
|
||||
SnapLen uint32
|
||||
// BufSize is the internal channel buffer size. 0 means 256.
|
||||
BufSize int
|
||||
}
|
||||
|
||||
// Stats reports capture session counters.
|
||||
type Stats struct {
|
||||
Packets int64
|
||||
Bytes int64
|
||||
Dropped int64
|
||||
}
|
||||
528
util/capture/filter.go
Normal file
528
util/capture/filter.go
Normal file
@@ -0,0 +1,528 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Matcher tests whether a raw packet should be captured.
|
||||
type Matcher interface {
|
||||
Match(data []byte) bool
|
||||
}
|
||||
|
||||
// Filter selects packets by flat AND'd criteria. Useful for structured APIs
|
||||
// (query params, proto fields). Implements Matcher.
|
||||
type Filter struct {
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
Host netip.Addr
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
Port uint16
|
||||
Proto uint8
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the filter has no criteria set.
|
||||
func (f *Filter) IsEmpty() bool {
|
||||
return !f.SrcIP.IsValid() && !f.DstIP.IsValid() && !f.Host.IsValid() &&
|
||||
f.SrcPort == 0 && f.DstPort == 0 && f.Port == 0 && f.Proto == 0
|
||||
}
|
||||
|
||||
// Match implements Matcher. All non-zero fields must match (AND).
|
||||
func (f *Filter) Match(data []byte) bool {
|
||||
if f.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
info, ok := parsePacketInfo(data)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if f.Host.IsValid() && info.srcIP != f.Host && info.dstIP != f.Host {
|
||||
return false
|
||||
}
|
||||
if f.SrcIP.IsValid() && info.srcIP != f.SrcIP {
|
||||
return false
|
||||
}
|
||||
if f.DstIP.IsValid() && info.dstIP != f.DstIP {
|
||||
return false
|
||||
}
|
||||
if f.Proto != 0 && info.proto != f.Proto {
|
||||
return false
|
||||
}
|
||||
if f.Port != 0 && info.srcPort != f.Port && info.dstPort != f.Port {
|
||||
return false
|
||||
}
|
||||
if f.SrcPort != 0 && info.srcPort != f.SrcPort {
|
||||
return false
|
||||
}
|
||||
if f.DstPort != 0 && info.dstPort != f.DstPort {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// exprNode evaluates a filter condition against pre-parsed packet info.
|
||||
type exprNode func(info *packetInfo) bool
|
||||
|
||||
// exprMatcher wraps an expression tree. Parses the packet once, then walks the tree.
|
||||
type exprMatcher struct {
|
||||
root exprNode
|
||||
}
|
||||
|
||||
func (m *exprMatcher) Match(data []byte) bool {
|
||||
info, ok := parsePacketInfo(data)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return m.root(&info)
|
||||
}
|
||||
|
||||
func nodeAnd(a, b exprNode) exprNode {
|
||||
return func(info *packetInfo) bool { return a(info) && b(info) }
|
||||
}
|
||||
|
||||
func nodeOr(a, b exprNode) exprNode {
|
||||
return func(info *packetInfo) bool { return a(info) || b(info) }
|
||||
}
|
||||
|
||||
func nodeNot(n exprNode) exprNode {
|
||||
return func(info *packetInfo) bool { return !n(info) }
|
||||
}
|
||||
|
||||
func nodeHost(addr netip.Addr) exprNode {
|
||||
return func(info *packetInfo) bool { return info.srcIP == addr || info.dstIP == addr }
|
||||
}
|
||||
|
||||
func nodeSrcHost(addr netip.Addr) exprNode {
|
||||
return func(info *packetInfo) bool { return info.srcIP == addr }
|
||||
}
|
||||
|
||||
func nodeDstHost(addr netip.Addr) exprNode {
|
||||
return func(info *packetInfo) bool { return info.dstIP == addr }
|
||||
}
|
||||
|
||||
func nodePort(port uint16) exprNode {
|
||||
return func(info *packetInfo) bool { return info.srcPort == port || info.dstPort == port }
|
||||
}
|
||||
|
||||
func nodeSrcPort(port uint16) exprNode {
|
||||
return func(info *packetInfo) bool { return info.srcPort == port }
|
||||
}
|
||||
|
||||
func nodeDstPort(port uint16) exprNode {
|
||||
return func(info *packetInfo) bool { return info.dstPort == port }
|
||||
}
|
||||
|
||||
func nodeProto(proto uint8) exprNode {
|
||||
return func(info *packetInfo) bool { return info.proto == proto }
|
||||
}
|
||||
|
||||
func nodeFamily(family uint8) exprNode {
|
||||
return func(info *packetInfo) bool { return info.family == family }
|
||||
}
|
||||
|
||||
func nodeNet(prefix netip.Prefix) exprNode {
|
||||
return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) || prefix.Contains(info.dstIP) }
|
||||
}
|
||||
|
||||
func nodeSrcNet(prefix netip.Prefix) exprNode {
|
||||
return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) }
|
||||
}
|
||||
|
||||
func nodeDstNet(prefix netip.Prefix) exprNode {
|
||||
return func(info *packetInfo) bool { return prefix.Contains(info.dstIP) }
|
||||
}
|
||||
|
||||
// packetInfo holds parsed header fields for filtering and display.
|
||||
type packetInfo struct {
|
||||
family uint8
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
proto uint8
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
hdrLen int
|
||||
}
|
||||
|
||||
func parsePacketInfo(data []byte) (packetInfo, bool) {
|
||||
if len(data) < 1 {
|
||||
return packetInfo{}, false
|
||||
}
|
||||
switch data[0] >> 4 {
|
||||
case 4:
|
||||
return parseIPv4Info(data)
|
||||
case 6:
|
||||
return parseIPv6Info(data)
|
||||
default:
|
||||
return packetInfo{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPv4Info(data []byte) (packetInfo, bool) {
|
||||
if len(data) < 20 {
|
||||
return packetInfo{}, false
|
||||
}
|
||||
ihl := int(data[0]&0x0f) * 4
|
||||
if ihl < 20 || len(data) < ihl {
|
||||
return packetInfo{}, false
|
||||
}
|
||||
info := packetInfo{
|
||||
family: 4,
|
||||
srcIP: netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]}),
|
||||
dstIP: netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]}),
|
||||
proto: data[9],
|
||||
hdrLen: ihl,
|
||||
}
|
||||
if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= ihl+4 {
|
||||
info.srcPort = binary.BigEndian.Uint16(data[ihl:])
|
||||
info.dstPort = binary.BigEndian.Uint16(data[ihl+2:])
|
||||
}
|
||||
return info, true
|
||||
}
|
||||
|
||||
// parseIPv6Info parses the fixed IPv6 header. It reads the Next Header field
|
||||
// directly, so packets with extension headers (hop-by-hop, routing, fragment,
|
||||
// etc.) will report the extension type as the protocol rather than the final
|
||||
// transport protocol. This is acceptable for a debug capture tool.
|
||||
func parseIPv6Info(data []byte) (packetInfo, bool) {
|
||||
if len(data) < 40 {
|
||||
return packetInfo{}, false
|
||||
}
|
||||
var src, dst [16]byte
|
||||
copy(src[:], data[8:24])
|
||||
copy(dst[:], data[24:40])
|
||||
info := packetInfo{
|
||||
family: 6,
|
||||
srcIP: netip.AddrFrom16(src),
|
||||
dstIP: netip.AddrFrom16(dst),
|
||||
proto: data[6],
|
||||
hdrLen: 40,
|
||||
}
|
||||
if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= 44 {
|
||||
info.srcPort = binary.BigEndian.Uint16(data[40:])
|
||||
info.dstPort = binary.BigEndian.Uint16(data[42:])
|
||||
}
|
||||
return info, true
|
||||
}
|
||||
|
||||
// ParseFilter parses a BPF-like filter expression and returns a Matcher.
|
||||
// Returns nil Matcher for an empty expression (match all).
|
||||
//
|
||||
// Grammar (mirrors common tcpdump BPF syntax):
|
||||
//
|
||||
// orExpr = andExpr ("or" andExpr)*
|
||||
// andExpr = unary ("and" unary)*
|
||||
// unary = "not" unary | "(" orExpr ")" | term
|
||||
//
|
||||
// term = "host" IP | "src" target | "dst" target
|
||||
// | "port" NUM | "net" PREFIX
|
||||
// | "tcp" | "udp" | "icmp" | "icmp6"
|
||||
// | "ip" | "ip6" | "proto" NUM
|
||||
// target = "host" IP | "port" NUM | "net" PREFIX | IP
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// host 10.0.0.1 and tcp port 443
|
||||
// not port 22
|
||||
// (host 10.0.0.1 or host 10.0.0.2) and tcp
|
||||
// ip6 and icmp6
|
||||
// net 10.0.0.0/24
|
||||
// src host 10.0.0.1 or dst port 80
|
||||
func ParseFilter(expr string) (Matcher, error) {
|
||||
tokens := tokenize(expr)
|
||||
if len(tokens) == 0 {
|
||||
return nil, nil //nolint:nilnil // nil Matcher means "match all"
|
||||
}
|
||||
|
||||
p := &parser{tokens: tokens}
|
||||
node, err := p.parseOr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.pos < len(p.tokens) {
|
||||
return nil, fmt.Errorf("unexpected token %q at position %d", p.tokens[p.pos], p.pos)
|
||||
}
|
||||
return &exprMatcher{root: node}, nil
|
||||
}
|
||||
|
||||
func tokenize(expr string) []string {
|
||||
expr = strings.TrimSpace(expr)
|
||||
if expr == "" {
|
||||
return nil
|
||||
}
|
||||
// Split on whitespace but keep parens as separate tokens.
|
||||
var tokens []string
|
||||
for _, field := range strings.Fields(expr) {
|
||||
tokens = append(tokens, splitParens(field)...)
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// splitParens splits "(foo)" into "(", "foo", ")".
|
||||
func splitParens(s string) []string {
|
||||
var out []string
|
||||
for strings.HasPrefix(s, "(") {
|
||||
out = append(out, "(")
|
||||
s = s[1:]
|
||||
}
|
||||
var trail []string
|
||||
for strings.HasSuffix(s, ")") {
|
||||
trail = append(trail, ")")
|
||||
s = s[:len(s)-1]
|
||||
}
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
out = append(out, trail...)
|
||||
return out
|
||||
}
|
||||
|
||||
type parser struct {
|
||||
tokens []string
|
||||
pos int
|
||||
}
|
||||
|
||||
func (p *parser) peek() string {
|
||||
if p.pos >= len(p.tokens) {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(p.tokens[p.pos])
|
||||
}
|
||||
|
||||
func (p *parser) next() string {
|
||||
tok := p.peek()
|
||||
if tok != "" {
|
||||
p.pos++
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func (p *parser) expect(tok string) error {
|
||||
got := p.next()
|
||||
if got != tok {
|
||||
return fmt.Errorf("expected %q, got %q", tok, got)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *parser) parseOr() (exprNode, error) {
|
||||
left, err := p.parseAnd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for p.peek() == "or" {
|
||||
p.next()
|
||||
right, err := p.parseAnd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
left = nodeOr(left, right)
|
||||
}
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseAnd() (exprNode, error) {
|
||||
left, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
tok := p.peek()
|
||||
if tok == "and" {
|
||||
p.next()
|
||||
right, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
left = nodeAnd(left, right)
|
||||
continue
|
||||
}
|
||||
// Implicit AND: two atoms without "and" between them.
|
||||
// Only if the next token starts an atom (not "or", ")", or EOF).
|
||||
if tok != "" && tok != "or" && tok != ")" {
|
||||
right, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
left = nodeAnd(left, right)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseUnary() (exprNode, error) {
|
||||
switch p.peek() {
|
||||
case "not":
|
||||
p.next()
|
||||
inner, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nodeNot(inner), nil
|
||||
case "(":
|
||||
p.next()
|
||||
inner, err := p.parseOr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := p.expect(")"); err != nil {
|
||||
return nil, fmt.Errorf("unclosed parenthesis")
|
||||
}
|
||||
return inner, nil
|
||||
default:
|
||||
return p.parseAtom()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) parseAtom() (exprNode, error) {
|
||||
tok := p.next()
|
||||
if tok == "" {
|
||||
return nil, fmt.Errorf("unexpected end of expression")
|
||||
}
|
||||
|
||||
switch tok {
|
||||
case "host":
|
||||
addr, err := p.parseAddr()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("host: %w", err)
|
||||
}
|
||||
return nodeHost(addr), nil
|
||||
|
||||
case "port":
|
||||
port, err := p.parsePort()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port: %w", err)
|
||||
}
|
||||
return nodePort(port), nil
|
||||
|
||||
case "net":
|
||||
prefix, err := p.parsePrefix()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("net: %w", err)
|
||||
}
|
||||
return nodeNet(prefix), nil
|
||||
|
||||
case "src":
|
||||
return p.parseDirTarget(true)
|
||||
|
||||
case "dst":
|
||||
return p.parseDirTarget(false)
|
||||
|
||||
case "tcp":
|
||||
return nodeProto(protoTCP), nil
|
||||
case "udp":
|
||||
return nodeProto(protoUDP), nil
|
||||
case "icmp":
|
||||
return nodeProto(protoICMP), nil
|
||||
case "icmp6":
|
||||
return nodeProto(protoICMPv6), nil
|
||||
case "ip":
|
||||
return nodeFamily(4), nil
|
||||
case "ip6":
|
||||
return nodeFamily(6), nil
|
||||
|
||||
case "proto":
|
||||
raw := p.next()
|
||||
if raw == "" {
|
||||
return nil, fmt.Errorf("proto: missing number")
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil || n < 0 || n > 255 {
|
||||
return nil, fmt.Errorf("proto: invalid number %q", raw)
|
||||
}
|
||||
return nodeProto(uint8(n)), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown filter keyword %q", tok)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) parseDirTarget(isSrc bool) (exprNode, error) {
|
||||
tok := p.peek()
|
||||
switch tok {
|
||||
case "host":
|
||||
p.next()
|
||||
addr, err := p.parseAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isSrc {
|
||||
return nodeSrcHost(addr), nil
|
||||
}
|
||||
return nodeDstHost(addr), nil
|
||||
|
||||
case "port":
|
||||
p.next()
|
||||
port, err := p.parsePort()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isSrc {
|
||||
return nodeSrcPort(port), nil
|
||||
}
|
||||
return nodeDstPort(port), nil
|
||||
|
||||
case "net":
|
||||
p.next()
|
||||
prefix, err := p.parsePrefix()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isSrc {
|
||||
return nodeSrcNet(prefix), nil
|
||||
}
|
||||
return nodeDstNet(prefix), nil
|
||||
|
||||
default:
|
||||
// Try as bare IP: "src 10.0.0.1"
|
||||
addr, err := p.parseAddr()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected host, port, net, or IP after src/dst, got %q", tok)
|
||||
}
|
||||
if isSrc {
|
||||
return nodeSrcHost(addr), nil
|
||||
}
|
||||
return nodeDstHost(addr), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) parseAddr() (netip.Addr, error) {
|
||||
raw := p.next()
|
||||
if raw == "" {
|
||||
return netip.Addr{}, fmt.Errorf("missing IP address")
|
||||
}
|
||||
addr, err := netip.ParseAddr(raw)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP %q", raw)
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func (p *parser) parsePort() (uint16, error) {
|
||||
raw := p.next()
|
||||
if raw == "" {
|
||||
return 0, fmt.Errorf("missing port number")
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil || n < 1 || n > 65535 {
|
||||
return 0, fmt.Errorf("invalid port %q", raw)
|
||||
}
|
||||
return uint16(n), nil
|
||||
}
|
||||
|
||||
func (p *parser) parsePrefix() (netip.Prefix, error) {
|
||||
raw := p.next()
|
||||
if raw == "" {
|
||||
return netip.Prefix{}, fmt.Errorf("missing network prefix")
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(raw)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("invalid prefix %q", raw)
|
||||
}
|
||||
return prefix, nil
|
||||
}
|
||||
263
util/capture/filter_test.go
Normal file
263
util/capture/filter_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// buildIPv4Packet creates a minimal IPv4+TCP/UDP packet for filter testing.
|
||||
func buildIPv4Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
|
||||
hdrLen := 20
|
||||
pkt := make([]byte, hdrLen+20)
|
||||
pkt[0] = 0x45
|
||||
pkt[9] = proto
|
||||
|
||||
src := srcIP.As4()
|
||||
dst := dstIP.As4()
|
||||
copy(pkt[12:16], src[:])
|
||||
copy(pkt[16:20], dst[:])
|
||||
|
||||
pkt[20] = byte(srcPort >> 8)
|
||||
pkt[21] = byte(srcPort)
|
||||
pkt[22] = byte(dstPort >> 8)
|
||||
pkt[23] = byte(dstPort)
|
||||
|
||||
return pkt
|
||||
}
|
||||
|
||||
// buildIPv6Packet creates a minimal IPv6+TCP/UDP packet for filter testing.
|
||||
func buildIPv6Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
|
||||
pkt := make([]byte, 44) // 40 header + 4 ports
|
||||
pkt[0] = 0x60 // version 6
|
||||
pkt[6] = proto // next header
|
||||
|
||||
src := srcIP.As16()
|
||||
dst := dstIP.As16()
|
||||
copy(pkt[8:24], src[:])
|
||||
copy(pkt[24:40], dst[:])
|
||||
|
||||
pkt[40] = byte(srcPort >> 8)
|
||||
pkt[41] = byte(srcPort)
|
||||
pkt[42] = byte(dstPort >> 8)
|
||||
pkt[43] = byte(dstPort)
|
||||
|
||||
return pkt
|
||||
}
|
||||
|
||||
// ---- Filter struct tests ----
|
||||
|
||||
func TestFilter_Empty(t *testing.T) {
|
||||
f := Filter{}
|
||||
assert.True(t, f.IsEmpty())
|
||||
assert.True(t, f.Match(buildIPv4Packet(t,
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
protoTCP, 12345, 443)))
|
||||
}
|
||||
|
||||
func TestFilter_Host(t *testing.T) {
|
||||
f := Filter{Host: netip.MustParseAddr("10.0.0.1")}
|
||||
assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 1234, 80)))
|
||||
assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.1"), protoTCP, 1234, 80)))
|
||||
assert.False(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.3"), protoTCP, 1234, 80)))
|
||||
}
|
||||
|
||||
func TestFilter_InvalidPacket(t *testing.T) {
|
||||
f := Filter{Host: netip.MustParseAddr("10.0.0.1")}
|
||||
assert.False(t, f.Match(nil))
|
||||
assert.False(t, f.Match([]byte{}))
|
||||
assert.False(t, f.Match([]byte{0x00}))
|
||||
}
|
||||
|
||||
func TestParsePacketInfo_IPv4(t *testing.T) {
|
||||
pkt := buildIPv4Packet(t, netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("10.0.0.1"), protoTCP, 54321, 80)
|
||||
info, ok := parsePacketInfo(pkt)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, uint8(4), info.family)
|
||||
assert.Equal(t, netip.MustParseAddr("192.168.1.1"), info.srcIP)
|
||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), info.dstIP)
|
||||
assert.Equal(t, uint8(protoTCP), info.proto)
|
||||
assert.Equal(t, uint16(54321), info.srcPort)
|
||||
assert.Equal(t, uint16(80), info.dstPort)
|
||||
}
|
||||
|
||||
func TestParsePacketInfo_IPv6(t *testing.T) {
|
||||
pkt := buildIPv6Packet(t, netip.MustParseAddr("fd00::1"), netip.MustParseAddr("fd00::2"), protoUDP, 1234, 53)
|
||||
info, ok := parsePacketInfo(pkt)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, uint8(6), info.family)
|
||||
assert.Equal(t, netip.MustParseAddr("fd00::1"), info.srcIP)
|
||||
assert.Equal(t, netip.MustParseAddr("fd00::2"), info.dstIP)
|
||||
assert.Equal(t, uint8(protoUDP), info.proto)
|
||||
assert.Equal(t, uint16(1234), info.srcPort)
|
||||
assert.Equal(t, uint16(53), info.dstPort)
|
||||
}
|
||||
|
||||
// ---- ParseFilter expression tests ----
|
||||
|
||||
func matchV4(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool {
|
||||
t.Helper()
|
||||
return m.Match(buildIPv4Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort))
|
||||
}
|
||||
|
||||
func matchV6(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool {
|
||||
t.Helper()
|
||||
return m.Match(buildIPv6Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort))
|
||||
}
|
||||
|
||||
func TestParseFilter_Empty(t *testing.T) {
|
||||
m, err := ParseFilter("")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, m, "empty expression should return nil matcher")
|
||||
}
|
||||
|
||||
func TestParseFilter_Atoms(t *testing.T) {
|
||||
tests := []struct {
|
||||
expr string
|
||||
match bool
|
||||
}{
|
||||
{"tcp", true},
|
||||
{"udp", false},
|
||||
{"host 10.0.0.1", true},
|
||||
{"host 10.0.0.99", false},
|
||||
{"port 443", true},
|
||||
{"port 80", false},
|
||||
{"src host 10.0.0.1", true},
|
||||
{"dst host 10.0.0.2", true},
|
||||
{"dst host 10.0.0.1", false},
|
||||
{"src port 12345", true},
|
||||
{"dst port 443", true},
|
||||
{"dst port 80", false},
|
||||
{"proto 6", true},
|
||||
{"proto 17", false},
|
||||
}
|
||||
|
||||
pkt := buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 12345, 443)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expr, func(t *testing.T) {
|
||||
m, err := ParseFilter(tt.expr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.match, m.Match(pkt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilter_And(t *testing.T) {
|
||||
m, err := ParseFilter("host 10.0.0.1 and tcp port 443")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 55555, 443), "wrong proto")
|
||||
assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host")
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 80), "wrong port")
|
||||
}
|
||||
|
||||
func TestParseFilter_ImplicitAnd(t *testing.T) {
|
||||
// "tcp port 443" = implicit AND between tcp and port 443
|
||||
m, err := ParseFilter("tcp port 443")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 1, 443))
|
||||
}
|
||||
|
||||
func TestParseFilter_Or(t *testing.T) {
|
||||
m, err := ParseFilter("port 80 or port 443")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 80))
|
||||
assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443))
|
||||
assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080))
|
||||
}
|
||||
|
||||
func TestParseFilter_Not(t *testing.T) {
|
||||
m, err := ParseFilter("not port 22")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 22))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 80))
|
||||
}
|
||||
|
||||
func TestParseFilter_Parens(t *testing.T) {
|
||||
m, err := ParseFilter("(port 80 or port 443) and tcp")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443))
|
||||
assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoUDP, 1, 443), "wrong proto")
|
||||
assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080), "wrong port")
|
||||
}
|
||||
|
||||
func TestParseFilter_Family(t *testing.T) {
|
||||
mV4, err := ParseFilter("ip")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, mV4, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80))
|
||||
assert.False(t, matchV6(t, mV4, "fd00::1", "fd00::2", protoTCP, 1, 80))
|
||||
|
||||
mV6, err := ParseFilter("ip6")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, matchV4(t, mV6, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80))
|
||||
assert.True(t, matchV6(t, mV6, "fd00::1", "fd00::2", protoTCP, 1, 80))
|
||||
}
|
||||
|
||||
func TestParseFilter_Net(t *testing.T) {
|
||||
m, err := ParseFilter("net 10.0.0.0/24")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "192.168.1.1", protoTCP, 1, 80), "src in net")
|
||||
assert.True(t, matchV4(t, m, "192.168.1.1", "10.0.0.200", protoTCP, 1, 80), "dst in net")
|
||||
assert.False(t, matchV4(t, m, "10.0.1.1", "192.168.1.1", protoTCP, 1, 80), "neither in net")
|
||||
}
|
||||
|
||||
func TestParseFilter_SrcDstNet(t *testing.T) {
|
||||
m, err := ParseFilter("src net 10.0.0.0/8 and dst net 192.168.0.0/16")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.1.2.3", "192.168.1.1", protoTCP, 1, 80))
|
||||
assert.False(t, matchV4(t, m, "192.168.1.1", "10.1.2.3", protoTCP, 1, 80), "reversed")
|
||||
}
|
||||
|
||||
func TestParseFilter_Complex(t *testing.T) {
|
||||
// Real-world: capture HTTP(S) traffic to/from specific host, excluding SSH
|
||||
m, err := ParseFilter("host 10.0.0.1 and (port 80 or port 443) and not port 22")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443))
|
||||
assert.True(t, matchV4(t, m, "10.0.0.2", "10.0.0.1", protoTCP, 55555, 80))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 443), "port 22 excluded")
|
||||
assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host")
|
||||
}
|
||||
|
||||
func TestParseFilter_IPv6Combined(t *testing.T) {
|
||||
m, err := ParseFilter("ip6 and icmp6")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV6(t, m, "fd00::1", "fd00::2", protoICMPv6, 0, 0))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoICMP, 0, 0), "wrong family")
|
||||
assert.False(t, matchV6(t, m, "fd00::1", "fd00::2", protoTCP, 1, 80), "wrong proto")
|
||||
}
|
||||
|
||||
func TestParseFilter_CaseInsensitive(t *testing.T) {
|
||||
m, err := ParseFilter("HOST 10.0.0.1 AND TCP PORT 443")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443))
|
||||
}
|
||||
|
||||
func TestParseFilter_Errors(t *testing.T) {
|
||||
bad := []string{
|
||||
"badkeyword",
|
||||
"host",
|
||||
"port abc",
|
||||
"port 99999",
|
||||
"net invalid",
|
||||
"(",
|
||||
"(port 80",
|
||||
"not",
|
||||
"src",
|
||||
}
|
||||
for _, expr := range bad {
|
||||
t.Run(expr, func(t *testing.T) {
|
||||
_, err := ParseFilter(expr)
|
||||
assert.Error(t, err, "should fail for %q", expr)
|
||||
})
|
||||
}
|
||||
}
|
||||
85
util/capture/pcap.go
Normal file
85
util/capture/pcap.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
pcapMagic = 0xa1b2c3d4
|
||||
pcapVersionMaj = 2
|
||||
pcapVersionMin = 4
|
||||
// linkTypeRaw is LINKTYPE_RAW: raw IPv4/IPv6 packets without link-layer header.
|
||||
linkTypeRaw = 101
|
||||
defaultSnapLen = 65535
|
||||
)
|
||||
|
||||
// PcapWriter writes packets in pcap format to an underlying writer.
|
||||
// The global header is written lazily on the first WritePacket call so that
|
||||
// the writer can be used with unbuffered io.Pipes without deadlocking.
|
||||
// It is not safe for concurrent use; callers must serialize access.
|
||||
type PcapWriter struct {
|
||||
w io.Writer
|
||||
snapLen uint32
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
// NewPcapWriter creates a pcap writer. The global header is deferred until the
|
||||
// first WritePacket call.
|
||||
func NewPcapWriter(w io.Writer, snapLen uint32) *PcapWriter {
|
||||
if snapLen == 0 {
|
||||
snapLen = defaultSnapLen
|
||||
}
|
||||
return &PcapWriter{w: w, snapLen: snapLen}
|
||||
}
|
||||
|
||||
// writeGlobalHeader writes the 24-byte pcap file header.
|
||||
func (pw *PcapWriter) writeGlobalHeader() error {
|
||||
var hdr [24]byte
|
||||
binary.LittleEndian.PutUint32(hdr[0:4], pcapMagic)
|
||||
binary.LittleEndian.PutUint16(hdr[4:6], pcapVersionMaj)
|
||||
binary.LittleEndian.PutUint16(hdr[6:8], pcapVersionMin)
|
||||
binary.LittleEndian.PutUint32(hdr[16:20], pw.snapLen)
|
||||
binary.LittleEndian.PutUint32(hdr[20:24], linkTypeRaw)
|
||||
|
||||
_, err := pw.w.Write(hdr[:])
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteHeader writes the pcap global header. Safe to call multiple times.
|
||||
func (pw *PcapWriter) WriteHeader() error {
|
||||
if pw.headerWritten {
|
||||
return nil
|
||||
}
|
||||
if err := pw.writeGlobalHeader(); err != nil {
|
||||
return err
|
||||
}
|
||||
pw.headerWritten = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// WritePacket writes a single packet record, preceded by the global header
|
||||
// on the first call.
|
||||
func (pw *PcapWriter) WritePacket(ts time.Time, data []byte) error {
|
||||
if err := pw.WriteHeader(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
origLen := uint32(len(data))
|
||||
if origLen > pw.snapLen {
|
||||
data = data[:pw.snapLen]
|
||||
}
|
||||
|
||||
var hdr [16]byte
|
||||
binary.LittleEndian.PutUint32(hdr[0:4], uint32(ts.Unix()))
|
||||
binary.LittleEndian.PutUint32(hdr[4:8], uint32(ts.Nanosecond()/1000))
|
||||
binary.LittleEndian.PutUint32(hdr[8:12], uint32(len(data)))
|
||||
binary.LittleEndian.PutUint32(hdr[12:16], origLen)
|
||||
|
||||
if _, err := pw.w.Write(hdr[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := pw.w.Write(data)
|
||||
return err
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user