mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-01 07:04:17 -04:00
Compare commits
1 Commits
netmap
...
feature/ip
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8b0398c0db |
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
store: ['sqlite']
|
||||
store: ['jsonfile', 'sqlite']
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
|
||||
39
.github/workflows/golang-test-freebsd.yml
vendored
39
.github/workflows/golang-test-freebsd.yml
vendored
@@ -1,39 +0,0 @@
|
||||
|
||||
name: Test Code FreeBSD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Test in FreeBSD
|
||||
id: test
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
with:
|
||||
usesh: true
|
||||
prepare: |
|
||||
pkg install -y curl
|
||||
pkg install -y git
|
||||
|
||||
run: |
|
||||
set -x
|
||||
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
|
||||
tar zxf go.tar.gz
|
||||
mv go /usr/local/go
|
||||
ln -s /usr/local/go/bin/go /usr/local/bin/go
|
||||
go mod tidy
|
||||
go test -timeout 5m -p 1 ./iface/...
|
||||
go test -timeout 5m -p 1 ./client/...
|
||||
cd client
|
||||
go build .
|
||||
cd ..
|
||||
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
store: [ 'jsonfile', 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -86,10 +86,7 @@ jobs:
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
||||
|
||||
- name: Generate SystemOps Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
||||
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
|
||||
|
||||
- name: Generate nftables Manager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
@@ -111,9 +108,6 @@ jobs:
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run SystemOps tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run nftables Manager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -173,7 +173,7 @@ jobs:
|
||||
retention-days: 3
|
||||
|
||||
release_ui_darwin:
|
||||
runs-on: macos-latest
|
||||
runs-on: macos-11
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
59
.github/workflows/test-infrastructure-files.yml
vendored
59
.github/workflows/test-infrastructure-files.yml
vendored
@@ -178,79 +178,34 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
- name: run script
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
- name: test Caddy file gen postgres
|
||||
- name: test Caddy file gen
|
||||
run: test -f Caddyfile
|
||||
|
||||
- name: test docker-compose file gen postgres
|
||||
- name: test docker-compose file gen
|
||||
run: test -f docker-compose.yml
|
||||
|
||||
- name: test management.json file gen postgres
|
||||
- name: test management.json file gen
|
||||
run: test -f management.json
|
||||
|
||||
- name: test turnserver.conf file gen postgres
|
||||
- name: test turnserver.conf file gen
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
|
||||
- name: test zitadel.env file gen postgres
|
||||
- name: test zitadel.env file gen
|
||||
run: test -f zitadel.env
|
||||
|
||||
- name: test dashboard.env file gen postgres
|
||||
- name: test dashboard.env file gen
|
||||
run: test -f dashboard.env
|
||||
|
||||
- name: test zdb.env file gen postgres
|
||||
run: test -f zdb.env
|
||||
|
||||
- name: Postgres run cleanup
|
||||
run: |
|
||||
docker-compose down --volumes --rmi all
|
||||
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
|
||||
|
||||
- name: run script with Zitadel CockroachDB
|
||||
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
env:
|
||||
NETBIRD_DOMAIN: use-ip
|
||||
ZITADEL_DATABASE: cockroach
|
||||
|
||||
- name: test Caddy file gen CockroachDB
|
||||
run: test -f Caddyfile
|
||||
|
||||
- name: test docker-compose file gen CockroachDB
|
||||
run: test -f docker-compose.yml
|
||||
|
||||
- name: test management.json file gen CockroachDB
|
||||
run: test -f management.json
|
||||
|
||||
- name: test turnserver.conf file gen CockroachDB
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
|
||||
- name: test zitadel.env file gen CockroachDB
|
||||
run: test -f zitadel.env
|
||||
|
||||
- name: test dashboard.env file gen CockroachDB
|
||||
run: test -f dashboard.env
|
||||
|
||||
test-download-geolite2-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: test script
|
||||
run: bash -x infrastructure_files/download-geolite2.sh
|
||||
|
||||
- name: test mmdb file exists
|
||||
run: test -f GeoLite2-City.mmdb
|
||||
|
||||
- name: test geonames file exists
|
||||
run: test -f geonames.db
|
||||
|
||||
@@ -3,10 +3,8 @@ builds:
|
||||
- id: netbird-ui-darwin
|
||||
dir: client/ui
|
||||
binary: netbird-ui
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- MACOSX_DEPLOYMENT_TARGET=11.0
|
||||
- MACOS_DEPLOYMENT_TARGET=11.0
|
||||
env: [CGO_ENABLED=1]
|
||||
|
||||
goos:
|
||||
- darwin
|
||||
goarch:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM alpine:3.19
|
||||
FROM alpine:3.18.5
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
|
||||
@@ -59,7 +59,7 @@ var forCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
}
|
||||
|
||||
func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -109,7 +109,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("invalid duration format: %v", err)
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -36,7 +36,6 @@ const (
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,9 +68,7 @@ var (
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Short: "",
|
||||
Long: "",
|
||||
@@ -356,11 +353,8 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
func getClient(ctx context.Context) (*grpc.ClientConn, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
|
||||
@@ -2,7 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
@@ -50,7 +49,7 @@ func init() {
|
||||
}
|
||||
|
||||
func routesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -67,62 +66,20 @@ func routesList(cmd *cobra.Command, _ []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
printRoutes(cmd, resp)
|
||||
cmd.Println("Available Routes:")
|
||||
for _, route := range resp.Routes {
|
||||
selectedStatus := "Not Selected"
|
||||
if route.GetSelected() {
|
||||
selectedStatus = "Selected"
|
||||
}
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
|
||||
cmd.Println("Available Routes:")
|
||||
for _, route := range resp.Routes {
|
||||
printRoute(cmd, route)
|
||||
}
|
||||
}
|
||||
|
||||
func printRoute(cmd *cobra.Command, route *proto.Route) {
|
||||
selectedStatus := getSelectedStatus(route)
|
||||
domains := route.GetDomains()
|
||||
|
||||
if len(domains) > 0 {
|
||||
printDomainRoute(cmd, route, domains, selectedStatus)
|
||||
} else {
|
||||
printNetworkRoute(cmd, route, selectedStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func getSelectedStatus(route *proto.Route) string {
|
||||
if route.GetSelected() {
|
||||
return "Selected"
|
||||
}
|
||||
return "Not Selected"
|
||||
}
|
||||
|
||||
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
||||
resolvedIPs := route.GetResolvedIPs()
|
||||
|
||||
if len(resolvedIPs) > 0 {
|
||||
printResolvedIPs(cmd, domains, resolvedIPs)
|
||||
} else {
|
||||
cmd.Printf(" Resolved IPs: -\n")
|
||||
}
|
||||
}
|
||||
|
||||
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
||||
}
|
||||
|
||||
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
|
||||
cmd.Printf(" Resolved IPs:\n")
|
||||
for _, domain := range domains {
|
||||
if ipList, exists := resolvedIPs[domain]; exists {
|
||||
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func routesSelect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -149,7 +106,7 @@ func routesSelect(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func routesDeselect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -807,7 +807,11 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = anonymizeRoute(a, route)
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -843,21 +847,12 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
overview.Routes[i] = anonymizeRoute(a, route)
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
||||
@@ -7,9 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -56,10 +53,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
srv, err := sig.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
sigProto.RegisterSignalExchangeServer(s, srv)
|
||||
sigProto.RegisterSignalExchangeServer(s, sig.NewServer())
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
panic(err)
|
||||
@@ -76,7 +70,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -87,13 +81,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
iv, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -108,7 +102,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
}
|
||||
|
||||
func startClientDaemon(
|
||||
t *testing.T, ctx context.Context, _, configPath string,
|
||||
t *testing.T, ctx context.Context, managementURL, configPath string,
|
||||
) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
@@ -7,13 +7,11 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -42,12 +40,8 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring")
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -143,10 +137,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
@@ -247,10 +237,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 0 {
|
||||
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
for i, err := range es {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%d errors occurred:\n\t%s",
|
||||
len(es), strings.Join(points, "\n\t"))
|
||||
}
|
||||
|
||||
func FormatErrorOrNil(err *multierror.Error) error {
|
||||
if err != nil {
|
||||
err.ErrorFormat = formatError
|
||||
}
|
||||
return err.ErrorOrNil()
|
||||
}
|
||||
@@ -30,3 +30,9 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
|
||||
}
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// Returns true if the current firewall implementation supports IPv6.
|
||||
// Currently false for anything non-linux.
|
||||
func SupportsIPv6() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -70,6 +70,8 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
|
||||
return nil, errUsp
|
||||
}
|
||||
|
||||
// Note for devs: When adding IPv6 support to userspace bind, the implementation of AllowNetbird() has to be
|
||||
// adjusted accordingly.
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
@@ -83,6 +85,12 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// Returns true if the current firewall implementation supports IPv6.
|
||||
// Currently true if the firewall is nftables.
|
||||
func SupportsIPv6() bool {
|
||||
return check() == NFTABLES
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() FWType {
|
||||
useIPTABLES := false
|
||||
|
||||
@@ -6,6 +6,7 @@ import "github.com/netbirdio/netbird/iface"
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
Address6() *iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(iface.PacketFilter) error
|
||||
}
|
||||
|
||||
@@ -24,6 +24,14 @@ type Manager struct {
|
||||
router *routerManager
|
||||
}
|
||||
|
||||
func (m *Manager) ResetV6Firewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) V6Active() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
|
||||
@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -101,7 +101,6 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
@@ -318,13 +317,6 @@ func (i *routerManager) createChain(table, newChain string) error {
|
||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||
}
|
||||
|
||||
// Add the loopback return rule to the NAT chain
|
||||
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
|
||||
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||
@@ -334,30 +326,6 @@ func (i *routerManager) createChain(table, newChain string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNATRule appends an iptables rule pair to the nat chain
|
||||
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
// inserting after loopback ignore rule
|
||||
err := i.iptablesClient.Insert(table, chain, 2, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification
|
||||
func genRuleSpec(jump, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump}
|
||||
|
||||
@@ -73,6 +73,9 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
if testCase.IsV6 {
|
||||
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
|
||||
}
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
@@ -154,6 +157,9 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
if testCase.IsV6 {
|
||||
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
|
||||
}
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
|
||||
@@ -76,6 +76,13 @@ type Manager interface {
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
RemoveRoutingRules(pair RouterPair) error
|
||||
|
||||
// ResetV6Firewall makes changes to the firewall to adapt to the IP address changes.
|
||||
// It is expected that after calling this method ApplyFiltering will be called to re-add the firewall rules.
|
||||
ResetV6Firewall() error
|
||||
|
||||
// V6Active returns whether IPv6 rules should/may be created by upper layers.
|
||||
V6Active() bool
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -36,17 +36,25 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable, err := m.createWorkTable()
|
||||
workTable, err := m.createWorkTable(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.router, err = newRouter(context, workTable)
|
||||
var workTable6 *nftables.Table
|
||||
if wgIface.Address6() != nil {
|
||||
workTable6, err = m.createWorkTable(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
m.router, err = newRouter(context, workTable, workTable6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
|
||||
m.aclManager, err = newAclManager(workTable, workTable6, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -54,6 +62,54 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Resets the IPv6 Firewall Table to adapt to changes in IP addresses
|
||||
func (m *Manager) ResetV6Firewall() error {
|
||||
|
||||
// First, prepare reset by deleting all currently active rules.
|
||||
workTable6, err := m.aclManager.PrepareV6Reset()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Depending on whether we now have an IPv6 address, we now either have to create/empty an IPv6 table, or delete it.
|
||||
if m.wgIface.Address6() != nil {
|
||||
if workTable6 != nil {
|
||||
m.rConn.FlushTable(workTable6)
|
||||
} else {
|
||||
workTable6, err = m.createWorkTable(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
m.rConn.DelTable(workTable6)
|
||||
workTable6 = nil
|
||||
}
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restore routing rules.
|
||||
err = m.router.RestoreAfterV6Reset(workTable6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restore basic firewall chains (needs to happen after routes because chains from router must exist).
|
||||
// Does not restore rules (will be done later during the update, when UpdateFiltering will be called at some point)
|
||||
err = m.aclManager.ReinitAfterV6Reset(workTable6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
func (m *Manager) V6Active() bool {
|
||||
return m.aclManager.v6Active
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
@@ -72,7 +128,7 @@ func (m *Manager) AddFiltering(
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
rawIP := ip.To4()
|
||||
if rawIP == nil {
|
||||
if rawIP == nil && m.wgIface.Address6() == nil {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
@@ -95,7 +151,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddRoutingRules(pair)
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
@@ -114,6 +170,8 @@ func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Note for devs: When adding IPv6 support to uspfilter, the implementation of createDefaultAllowRules()
|
||||
// must be adjusted to include IPv6 rules.
|
||||
err := m.aclManager.createDefaultAllowRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create default allow rules: %v", err)
|
||||
@@ -211,8 +269,8 @@ func (m *Manager) Flush() error {
|
||||
return m.aclManager.Flush()
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
func (m *Manager) createWorkTable(tableFamily nftables.TableFamily) (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(tableFamily)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
@@ -223,7 +281,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: tableFamily})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
@@ -19,8 +19,9 @@ import (
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() iface.WGAddress
|
||||
NameFunc func() string
|
||||
AddressFunc func() iface.WGAddress
|
||||
Address6Func func() *iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Name() string {
|
||||
@@ -37,6 +38,13 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Address6() *iface.WGAddress {
|
||||
if i.Address6Func != nil {
|
||||
return i.Address6Func()
|
||||
}
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
@@ -53,6 +61,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
},
|
||||
}
|
||||
},
|
||||
Address6Func: func() *iface.WGAddress { return nil },
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
@@ -99,11 +108,9 @@ func TestNftablesManager(t *testing.T) {
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(9),
|
||||
Len: uint32(1),
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
@@ -152,6 +159,370 @@ func TestNftablesManager(t *testing.T) {
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
func TestNftablesManager6Disabled(t *testing.T) {
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
Address6Func: func() *iface.WGAddress { return nil },
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("2001:db8::fedc:ba09:8765:4321")
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
_, err = manager.AddFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{53}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
require.Error(t, err, "IPv6 rule should not be added when IPv6 is disabled")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 0, "expected no rules")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
func TestNftablesManager6(t *testing.T) {
|
||||
|
||||
if !iface.SupportsIPv6() {
|
||||
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
|
||||
}
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
Address6Func: func() *iface.WGAddress {
|
||||
return &iface.WGAddress{
|
||||
IP: net.ParseIP("2001:db8::0123:4567:890a:bcde"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("2001:db8::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
require.True(t, manager.V6Active(), "IPv6 is not active even though it should be.")
|
||||
|
||||
ip := net.ParseIP("2001:db8::fedc:ba09:8765:4321")
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{53}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 1, "expected 1 rules")
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 53},
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
||||
|
||||
for _, r := range rule {
|
||||
err = manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
func TestNftablesManagerAddressReset6(t *testing.T) {
|
||||
|
||||
if !iface.SupportsIPv6() {
|
||||
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
|
||||
}
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
Address6Func: func() *iface.WGAddress {
|
||||
return &iface.WGAddress{
|
||||
IP: net.ParseIP("2001:db8::0123:4567:890a:bcde"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("2001:db8::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
require.True(t, manager.V6Active(), "IPv6 is not active even though it should be.")
|
||||
|
||||
ip := net.ParseIP("2001:db8::fedc:ba09:8765:4321")
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
_, err = manager.AddFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{53}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 1, "expected 1 rules")
|
||||
|
||||
mock.Address6Func = func() *iface.WGAddress {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = manager.ResetV6Firewall()
|
||||
require.NoError(t, err, "failed to reset IPv6 firewall")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
require.False(t, manager.V6Active(), "IPv6 is active even though it shouldn't be.")
|
||||
|
||||
tables, err := testClient.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
require.NoError(t, err, "failed to list IPv6 tables")
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == tableName {
|
||||
t.Errorf("When IPv6 is disabled, the netbird table should not exist.")
|
||||
}
|
||||
}
|
||||
|
||||
mock.Address6Func = func() *iface.WGAddress {
|
||||
return &iface.WGAddress{
|
||||
IP: net.ParseIP("2001:db8::0123:4567:890a:bcdf"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("2001:db8::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
err = manager.ResetV6Firewall()
|
||||
require.NoError(t, err, "failed to reset IPv6 firewall")
|
||||
|
||||
require.True(t, manager.V6Active(), "IPv6 is not active even though it should be.")
|
||||
|
||||
rule, err := manager.AddFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{53}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 1, "expected 1 rule")
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 53},
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
||||
|
||||
for _, r := range rule {
|
||||
err = manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
@@ -166,6 +537,16 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
},
|
||||
}
|
||||
},
|
||||
Address6Func: func() *iface.WGAddress {
|
||||
v6addr, v6net, _ := net.ParseCIDR("fd00:1234:dead:beef::1/64")
|
||||
return &iface.WGAddress{
|
||||
IP: v6addr,
|
||||
Network: &net.IPNet{
|
||||
IP: v6net.IP,
|
||||
Mask: v6net.Mask,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
|
||||
@@ -22,13 +22,12 @@ const (
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
|
||||
loopbackInterface = "lo\x00"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
var (
|
||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
||||
zeroXor6 = []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}
|
||||
|
||||
exprCounterAccept = []expr.Any{
|
||||
&expr.Counter{},
|
||||
@@ -41,48 +40,69 @@ var (
|
||||
)
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
workTable6 *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
filterTable6 *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
chains6 map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
isDefaultFwdRulesEnabled bool
|
||||
rules map[string]*nftables.Rule
|
||||
rules6 map[string]*nftables.Rule
|
||||
isDefaultFwdRulesEnabled bool
|
||||
isDefaultFwdRulesEnabled6 bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table, workTable6 *nftables.Table) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
workTable6: workTable6,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
chains6: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
rules6: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
r.filterTable, r.filterTable6, err = r.loadFilterTables()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
log.Warnf("table 'filter' not found for forward rules for one of the supported address families-")
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
err = r.cleanUpDefaultForwardRules(false)
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers()
|
||||
err = r.cleanUpDefaultForwardRules(true)
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from IPv6 FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers(false)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
|
||||
if r.workTable6 != nil {
|
||||
err = r.createContainers(true)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create v6 containers for route: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return r, err
|
||||
}
|
||||
|
||||
@@ -92,59 +112,99 @@ func (r *router) RouteingFwChainName() string {
|
||||
|
||||
// ResetForwardRules cleans existing nftables default forward rules from the system
|
||||
func (r *router) ResetForwardRules() {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
err := r.cleanUpDefaultForwardRules(false)
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset forward rules: %s", err)
|
||||
}
|
||||
err = r.cleanUpDefaultForwardRules(true)
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset forward rules: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
func (r *router) RestoreAfterV6Reset(newWorktable6 *nftables.Table) error {
|
||||
r.workTable6 = newWorktable6
|
||||
if newWorktable6 != nil {
|
||||
|
||||
err := r.cleanUpDefaultForwardRules(true)
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from IPv6 FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, rule := range r.rules6 {
|
||||
rule = &nftables.Rule{
|
||||
Table: r.workTable6,
|
||||
Chain: r.chains6[rule.Chain.Name],
|
||||
Exprs: rule.Exprs,
|
||||
UserData: rule.UserData,
|
||||
}
|
||||
r.rules6[name] = r.conn.AddRule(rule)
|
||||
}
|
||||
}
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTables() (*nftables.Table, *nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
return nil, nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
var table4 *nftables.Table = nil
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
table4 = table
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
var table6 *nftables.Table = nil
|
||||
tables, err = r.conn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
table6 = table
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
err = nil
|
||||
if table4 == nil || table6 == nil {
|
||||
err = errFilterTableNotFound
|
||||
}
|
||||
|
||||
return table4, table6, err
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
func (r *router) createContainers(forV6 bool) error {
|
||||
workTable := r.workTable
|
||||
chainStorage := r.chains
|
||||
if forV6 {
|
||||
workTable = r.workTable6
|
||||
chainStorage = r.chains6
|
||||
}
|
||||
|
||||
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
chainStorage[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRouteingFw,
|
||||
Table: r.workTable,
|
||||
Table: workTable,
|
||||
})
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
chainStorage[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Table: workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Add RETURN rule for loopback interface
|
||||
loRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte(loopbackInterface),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictReturn},
|
||||
},
|
||||
}
|
||||
r.conn.InsertRule(loRule)
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
err := r.refreshRulesMap(forV6)
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
@@ -156,34 +216,44 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||
parsedIp, _, _ := net.ParseCIDR(pair.Source)
|
||||
|
||||
if parsedIp.To4() == nil && r.workTable6 == nil {
|
||||
return fmt.Errorf("nftables: attempted to add IPv6 routing rule even though IPv6 is not enabled for this host")
|
||||
}
|
||||
|
||||
err := r.refreshRulesMap(parsedIp.To4() == nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
||||
filterTable := r.filterTable
|
||||
if parsedIp.To4() == nil {
|
||||
filterTable = r.filterTable6
|
||||
}
|
||||
if filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
||||
log.Debugf("add default accept forward rule")
|
||||
r.acceptForwardRule(pair.Source)
|
||||
}
|
||||
@@ -195,8 +265,8 @@ func (r *router) AddRoutingRules(pair manager.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
@@ -209,7 +279,13 @@ func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPai
|
||||
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
_, exists := r.rules[ruleKey]
|
||||
parsedIp, _, _ := net.ParseCIDR(pair.Source)
|
||||
rules := r.rules
|
||||
if parsedIp.To4() == nil {
|
||||
rules = r.rules6
|
||||
}
|
||||
|
||||
_, exists := rules[ruleKey]
|
||||
if exists {
|
||||
err := r.removeRoutingRule(format, pair)
|
||||
if err != nil {
|
||||
@@ -217,18 +293,35 @@ func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPai
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainName],
|
||||
table, chain := r.workTable, r.chains[chainName]
|
||||
if parsedIp.To4() == nil {
|
||||
table, chain = r.workTable6, r.chains6[chainName]
|
||||
}
|
||||
|
||||
newRule := r.conn.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expression,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
|
||||
if parsedIp.To4() == nil {
|
||||
r.rules[ruleKey] = newRule
|
||||
} else {
|
||||
r.rules6[ruleKey] = newRule
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
src := generateCIDRMatcherExpressions(true, sourceNetwork)
|
||||
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
|
||||
table := r.filterTable
|
||||
parsedIp, _, _ := net.ParseCIDR(sourceNetwork)
|
||||
if parsedIp.To4() == nil {
|
||||
dst = generateCIDRMatcherExpressions(false, "::/0")
|
||||
table = r.filterTable6
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
|
||||
@@ -236,10 +329,10 @@ func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
})...)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Table: table,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
@@ -251,6 +344,9 @@ func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
r.conn.AddRule(rule)
|
||||
|
||||
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
|
||||
if parsedIp.To4() == nil {
|
||||
src = generateCIDRMatcherExpressions(true, "::/0")
|
||||
}
|
||||
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
|
||||
|
||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
||||
@@ -258,10 +354,10 @@ func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
})...)
|
||||
|
||||
rule = &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Table: table,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
@@ -270,12 +366,21 @@ func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||
}
|
||||
r.conn.AddRule(rule)
|
||||
r.isDefaultFwdRulesEnabled = true
|
||||
if parsedIp.To4() == nil {
|
||||
r.isDefaultFwdRulesEnabled6 = true
|
||||
} else {
|
||||
r.isDefaultFwdRulesEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
parsedIp, _, _ := net.ParseCIDR(pair.Source)
|
||||
if parsedIp.To4() == nil && r.workTable6 == nil {
|
||||
return fmt.Errorf("nftables: attempted to remove IPv6 routing rule even though IPv6 is not enabled for this host")
|
||||
}
|
||||
|
||||
err := r.refreshRulesMap(parsedIp.To4() == nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -300,8 +405,12 @@ func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(r.rules) == 0 {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
rulesList := r.rules
|
||||
if parsedIp.To4() == nil {
|
||||
rulesList = r.rules6
|
||||
}
|
||||
if len(rulesList) == 0 {
|
||||
err := r.cleanUpDefaultForwardRules(parsedIp.To4() == nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
@@ -319,7 +428,13 @@ func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
||||
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
rule, found := r.rules[ruleKey]
|
||||
parsedIp, _, _ := net.ParseCIDR(pair.Source)
|
||||
rules := r.rules
|
||||
if parsedIp.To4() == nil {
|
||||
rules = r.rules6
|
||||
}
|
||||
|
||||
rule, found := rules[ruleKey]
|
||||
if found {
|
||||
ruleType := "forwarding"
|
||||
if rule.Chain.Type == nftables.ChainTypeNAT {
|
||||
@@ -333,49 +448,68 @@ func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error
|
||||
|
||||
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
delete(rules, ruleKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
func (r *router) refreshRulesMap(forV6 bool) error {
|
||||
chainList := r.chains
|
||||
if forV6 {
|
||||
chainList = r.chains6
|
||||
}
|
||||
for _, chain := range chainList {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
if forV6 {
|
||||
r.rules6[string(rule.UserData)] = rule
|
||||
} else {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
func (r *router) cleanUpDefaultForwardRules(forV6 bool) error {
|
||||
tableFamily := nftables.TableFamilyIPv4
|
||||
filterTable := r.filterTable
|
||||
if forV6 {
|
||||
tableFamily = nftables.TableFamilyIPv6
|
||||
filterTable = r.filterTable6
|
||||
}
|
||||
|
||||
if filterTable == nil {
|
||||
if forV6 {
|
||||
r.isDefaultFwdRulesEnabled6 = false
|
||||
} else {
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
chains, err := r.conn.ListChainsOfTableFamily(tableFamily)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rules []*nftables.Rule
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
if chain.Table.Name != filterTable.Name {
|
||||
continue
|
||||
}
|
||||
if chain.Name != "FORWARD" {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err = r.conn.GetRules(r.filterTable, chain)
|
||||
rules, err = r.conn.GetRules(filterTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -389,7 +523,12 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
|
||||
if forV6 {
|
||||
r.isDefaultFwdRulesEnabled6 = false
|
||||
} else {
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
}
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
@@ -405,6 +544,18 @@ func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
||||
} else {
|
||||
offSet = 16 // dst offset
|
||||
}
|
||||
addrLen := uint32(4)
|
||||
zeroXor := zeroXor
|
||||
|
||||
if ip.To4() == nil {
|
||||
if source {
|
||||
offSet = 8 // src offset
|
||||
} else {
|
||||
offSet = 24 // dst offset
|
||||
}
|
||||
addrLen = 16
|
||||
zeroXor = zeroXor6
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
// fetch src add
|
||||
@@ -412,13 +563,13 @@ func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offSet,
|
||||
Len: 4,
|
||||
Len: addrLen,
|
||||
},
|
||||
// net mask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Len: addrLen,
|
||||
Mask: network.Mask,
|
||||
Xor: zeroXor,
|
||||
},
|
||||
|
||||
@@ -4,6 +4,7 @@ package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
@@ -29,16 +30,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
table, table6, err := createWorkTables()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
defer deleteWorkTables()
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
if testCase.IsV6 && table6 == nil {
|
||||
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
|
||||
}
|
||||
manager, err := newRouter(context.TODO(), table, table6)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
@@ -47,7 +51,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.AddRoutingRules(testCase.InputPair)
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
defer func() {
|
||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
}()
|
||||
@@ -58,8 +62,13 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
chains := manager.chains
|
||||
if testCase.IsV6 {
|
||||
chains = manager.chains6
|
||||
}
|
||||
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
for _, chain := range chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
@@ -75,7 +84,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
if testCase.InputPair.Masquerade {
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
for _, chain := range chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
@@ -94,7 +103,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
found = 0
|
||||
for _, chain := range manager.chains {
|
||||
for _, chain := range chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
@@ -110,7 +119,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
if testCase.InputPair.Masquerade {
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
for _, chain := range chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
@@ -131,16 +140,19 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
table, table6, err := createWorkTables()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
defer deleteWorkTables()
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
if testCase.IsV6 && table6 == nil {
|
||||
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
|
||||
}
|
||||
manager, err := newRouter(context.TODO(), table, table6)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
@@ -150,11 +162,18 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
chains := manager.chains
|
||||
workTable := table
|
||||
if testCase.IsV6 {
|
||||
chains = manager.chains6
|
||||
workTable = table6
|
||||
}
|
||||
|
||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Table: workTable,
|
||||
Chain: chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(forwardRuleKey),
|
||||
})
|
||||
@@ -163,8 +182,8 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Table: workTable,
|
||||
Chain: chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(natRuleKey),
|
||||
})
|
||||
@@ -175,8 +194,8 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Table: workTable,
|
||||
Chain: chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(inForwardRuleKey),
|
||||
})
|
||||
@@ -185,8 +204,8 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Table: workTable,
|
||||
Chain: chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(inNatRuleKey),
|
||||
})
|
||||
@@ -199,7 +218,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
for _, chain := range manager.chains {
|
||||
for _, chain := range chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
@@ -238,30 +257,39 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func createWorkTable() (*nftables.Table, error) {
|
||||
func createWorkTables() (*nftables.Table, *nftables.Table, error) {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
tables6, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, t := range append(tables, tables6...) {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
var table6 *nftables.Table
|
||||
if iface.SupportsIPv6() {
|
||||
table6 = sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6})
|
||||
}
|
||||
err = sConn.Flush()
|
||||
|
||||
return table, err
|
||||
return table, table6, err
|
||||
}
|
||||
|
||||
func deleteWorkTable() {
|
||||
func deleteWorkTables() {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return
|
||||
@@ -272,6 +300,12 @@ func deleteWorkTable() {
|
||||
return
|
||||
}
|
||||
|
||||
tables6, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tables = append(tables, tables6...)
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
|
||||
@@ -8,6 +8,7 @@ var (
|
||||
InsertRuleTestCases = []struct {
|
||||
Name string
|
||||
InputPair firewall.RouterPair
|
||||
IsV6 bool
|
||||
}{
|
||||
{
|
||||
Name: "Insert Forwarding IPV4 Rule",
|
||||
@@ -27,12 +28,32 @@ var (
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Insert Forwarding IPV6 Rule",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "2001:db8:0123:4567::1/128",
|
||||
Destination: "2001:db8:0123:abcd::/64",
|
||||
Masquerade: false,
|
||||
},
|
||||
IsV6: true,
|
||||
},
|
||||
{
|
||||
Name: "Insert Forwarding And Nat IPV6 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "2001:db8:0123:4567::1/128",
|
||||
Destination: "2001:db8:0123:abcd::/64",
|
||||
Masquerade: true,
|
||||
},
|
||||
IsV6: true,
|
||||
},
|
||||
}
|
||||
|
||||
RemoveRuleTestCases = []struct {
|
||||
Name string
|
||||
InputPair firewall.RouterPair
|
||||
IpVersion string
|
||||
IsV6 bool
|
||||
}{
|
||||
{
|
||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
@@ -43,5 +64,15 @@ var (
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Remove Forwarding And Nat IPV6 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "2001:db8:0123:4567::1/128",
|
||||
Destination: "2001:db8:0123:abcd::/64",
|
||||
Masquerade: true,
|
||||
},
|
||||
IsV6: true,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ var (
|
||||
type IFaceMapper interface {
|
||||
SetFilter(iface.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
Address6() *iface.WGAddress
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
@@ -69,6 +70,14 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func (m *Manager) ResetV6Firewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) V6Active() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func create(iface IFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
|
||||
@@ -33,6 +33,10 @@ func (i *IFaceMock) Address() iface.WGAddress {
|
||||
return i.AddressFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Address6() *iface.WGAddress {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestManagerCreate(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
|
||||
@@ -16,9 +16,10 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
// Manager is a ACL rules manager
|
||||
// Manager is an ACL rules manager
|
||||
type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||
ResetV6Acl() error
|
||||
}
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
@@ -26,16 +27,36 @@ type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
rulesPairs6 map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
firewall: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
firewall: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
rulesPairs6: make(map[string][]firewall.Rule),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) ResetV6Acl() error {
|
||||
for _, rules := range d.rulesPairs6 {
|
||||
for _, r := range rules {
|
||||
err := d.firewall.DeleteRule(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
err := d.firewall.ResetV6Firewall()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.rulesPairs6 = make(map[string][]firewall.Rule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||
//
|
||||
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
||||
@@ -83,6 +104,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
if enableSSH {
|
||||
rules = append(rules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP6: "::",
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
@@ -97,12 +119,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
rules = append(rules,
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP6: "::",
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP6: "::",
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
@@ -111,6 +135,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
}
|
||||
|
||||
newRulePairs := make(map[string][]firewall.Rule)
|
||||
newRulePairs6 := make(map[string][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
|
||||
for _, r := range rules {
|
||||
@@ -123,7 +148,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
ipsetByRuleSelectors[selector] = ipsetName
|
||||
}
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
pairID, rulePair, rulePair6, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
d.rollBack(newRulePairs)
|
||||
@@ -132,6 +157,8 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
if len(rules) > 0 {
|
||||
d.rulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
d.rulesPairs6[pairID] = rulePair6
|
||||
newRulePairs6[pairID] = rulePair6
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,59 +173,104 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
delete(d.rulesPairs, pairID)
|
||||
}
|
||||
}
|
||||
for pairID, rules := range d.rulesPairs6 {
|
||||
if _, ok := newRulePairs6[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
delete(d.rulesPairs6, pairID)
|
||||
}
|
||||
}
|
||||
|
||||
d.rulesPairs = newRulePairs
|
||||
d.rulesPairs6 = newRulePairs6
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (string, []firewall.Rule, error) {
|
||||
) (string, []firewall.Rule, []firewall.Rule, error) {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
return "", nil, nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
|
||||
var ip6 *net.IP = nil
|
||||
if d.firewall.V6Active() && r.PeerIP6 != "" {
|
||||
ip6tmp := net.ParseIP(r.PeerIP6)
|
||||
if ip6tmp == nil {
|
||||
return "", nil, nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
ip6 = &ip6tmp
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
return "", nil, nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
}
|
||||
|
||||
action, err := convertFirewallAction(r.Action)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
return "", nil, nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if r.Port != "" {
|
||||
value, err := strconv.Atoi(r.Port)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||
return "", nil, nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||
}
|
||||
port = &firewall.Port{
|
||||
Values: []int{value},
|
||||
}
|
||||
}
|
||||
|
||||
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||
var rules []firewall.Rule
|
||||
var rules6 []firewall.Rule
|
||||
|
||||
ruleID := d.getRuleID(ip, ip6, protocol, int(r.Direction), port, action, "")
|
||||
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
rules = rulesPair
|
||||
}
|
||||
if rulesPair6, ok := d.rulesPairs6[ruleID]; d.firewall.V6Active() && ok && ip6 != nil {
|
||||
rules6 = rulesPair6
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
switch r.Direction {
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.FirewallRule_OUT:
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
if rules == nil {
|
||||
switch r.Direction {
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.FirewallRule_OUT:
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
return ruleID, rules, nil
|
||||
if d.firewall.V6Active() && ip6 != nil && rules6 == nil {
|
||||
switch r.Direction {
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules6, err = d.addInRules(*ip6, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.FirewallRule_OUT:
|
||||
rules6, err = d.addOutRules(*ip6, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if err != nil && err.Error() != "failed to add firewall rule: attempted to configure filtering for IPv6 address even though IPv6 is not active" {
|
||||
return "", rules, nil, err
|
||||
}
|
||||
|
||||
return ruleID, rules, rules6, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
@@ -226,8 +298,9 @@ func (d *DefaultManager) addInRules(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule...)
|
||||
|
||||
return append(rules, rule...), nil
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
@@ -255,20 +328,26 @@ func (d *DefaultManager) addOutRules(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule...)
|
||||
|
||||
return append(rules, rule...), nil
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getRuleID(
|
||||
ip net.IP,
|
||||
ip6 *net.IP,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
comment string,
|
||||
) string {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||
ip6Str := ""
|
||||
if ip6 != nil {
|
||||
ip6Str = ip6.String()
|
||||
}
|
||||
idStr := ip.String() + ip6Str + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
@@ -321,6 +400,8 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
// it means that rules for that protocol was already optimized on the
|
||||
// management side
|
||||
if r.PeerIP == "0.0.0.0" {
|
||||
// I don't _think_ that IPv6 is relevant here, as any optimization that has r.PeerIP6 == "::" should also
|
||||
// implicitly have r.PeerIP == "0.0.0.0".
|
||||
squashedRules = append(squashedRules, r)
|
||||
squashedProtocols[r.Protocol] = struct{}{}
|
||||
return
|
||||
@@ -364,6 +445,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
|
||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP6: "::",
|
||||
Direction: direction,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: protocol,
|
||||
|
||||
@@ -19,6 +19,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
PeerIP6: "2001:db8::fedc:ba09:8765:0001",
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
@@ -26,6 +27,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
PeerIP6: "2001:db8::fedc:ba09:8765:0002",
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_DROP,
|
||||
Protocol: mgmProto.FirewallRule_UDP,
|
||||
@@ -50,6 +52,14 @@ func TestDefaultManager(t *testing.T) {
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ip6, network6, err := net.ParseCIDR("2001:db8::fedc:ba09:8765:4321/64")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
ifaceMock.EXPECT().Address6().Return(&iface.WGAddress{
|
||||
IP: ip6,
|
||||
Network: network6,
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
@@ -83,6 +93,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
networkMap.FirewallRules,
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "10.93.0.3",
|
||||
PeerIP6: "2001:db8::fedc:ba09:8765:0003",
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_DROP,
|
||||
Protocol: mgmProto.FirewallRule_ICMP,
|
||||
@@ -343,6 +354,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().Address6().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/netbirdio/netbird/client/internal/acl (interfaces: IFaceMapper)
|
||||
// Source: ./client/firewall/iface.go
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
@@ -48,6 +48,20 @@ func (mr *MockIFaceMapperMockRecorder) Address() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Address", reflect.TypeOf((*MockIFaceMapper)(nil).Address))
|
||||
}
|
||||
|
||||
// Address6 mocks base method.
|
||||
func (m *MockIFaceMapper) Address6() *iface.WGAddress {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Address6")
|
||||
ret0, _ := ret[0].(*iface.WGAddress)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Address6 indicates an expected call of Address6.
|
||||
func (mr *MockIFaceMapperMockRecorder) Address6() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Address6", reflect.TypeOf((*MockIFaceMapper)(nil).Address6))
|
||||
}
|
||||
|
||||
// IsUserspaceBind mocks base method.
|
||||
func (m *MockIFaceMapper) IsUserspaceBind() bool {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -6,16 +6,13 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
@@ -56,7 +53,6 @@ type ConfigInput struct {
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -68,7 +64,7 @@ type Config struct {
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
NetworkMonitor bool
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
@@ -99,9 +95,6 @@ type Config struct {
|
||||
// DisableAutoConnect determines whether the client should not start with the service
|
||||
// it's set to false by default due to backwards compatibility
|
||||
DisableAutoConnect bool
|
||||
|
||||
// DNSRouteInterval is the interval in which the DNS routes are updated
|
||||
DNSRouteInterval time.Duration
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
@@ -311,21 +304,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
|
||||
if input.NetworkMonitor != nil && *input.NetworkMonitor != config.NetworkMonitor {
|
||||
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
|
||||
config.NetworkMonitor = input.NetworkMonitor
|
||||
config.NetworkMonitor = *input.NetworkMonitor
|
||||
updated = true
|
||||
}
|
||||
|
||||
if config.NetworkMonitor == nil {
|
||||
// enable network monitoring by default on windows and darwin clients
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
|
||||
enabled := true
|
||||
config.NetworkMonitor = &enabled
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
|
||||
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
|
||||
log.Infof("updating custom DNS address %#v (old value %#v)",
|
||||
string(input.CustomDNSAddress), config.CustomDNSAddress)
|
||||
@@ -373,18 +357,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||
config.DNSRouteInterval = *input.DNSRouteInterval
|
||||
updated = true
|
||||
} else if config.DNSRouteInterval == 0 {
|
||||
config.DNSRouteInterval = dynamic.DefaultInterval
|
||||
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
|
||||
updated = true
|
||||
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -252,10 +252,8 @@ func (c *ConnectClient) run(
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
err = c.engine.Start()
|
||||
@@ -309,25 +307,21 @@ func (c *ConnectClient) Engine() *Engine {
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
nm := false
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
}
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
WgAddr6: peerConfig.Address6,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
NetworkMonitor: config.NetworkMonitor,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
|
||||
)
|
||||
@@ -1,8 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
@@ -116,10 +116,16 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
var value string
|
||||
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
|
||||
if err == nil {
|
||||
if value == systemdDbusResolvConfModeForeign {
|
||||
return systemdManager, nil
|
||||
}
|
||||
}
|
||||
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -485,7 +485,11 @@ func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
if ns.IP.Is4() {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
} else {
|
||||
return fmt.Sprintf("[%s]:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
}
|
||||
|
||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -38,9 +38,12 @@ func (w *mocWGIface) Address() iface.WGAddress {
|
||||
Network: network,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mocWGIface) ToInterface() *net.Interface {
|
||||
panic("implement me")
|
||||
func (w *mocWGIface) Address6() *iface.WGAddress {
|
||||
ip, network, _ := net.ParseCIDR("fd00:1234:dead:beef::/64")
|
||||
return &iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||
@@ -265,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), fmt.Sprintf("fd00:1234:dead:beef::%d/128", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -343,7 +346,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", "", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
@@ -599,7 +602,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t, false)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
@@ -625,7 +628,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t, false)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
@@ -717,7 +720,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t, false)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
@@ -788,7 +791,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
func createWgInterfaceWithBind(t *testing.T, enableV6 bool) (*iface.WGIface, error) {
|
||||
t.Helper()
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
@@ -801,7 +804,11 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
v6Addr := ""
|
||||
if enableV6 {
|
||||
v6Addr = "fd00:1234:dead:beef::1/128"
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", v6Addr, 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build interface wireguard: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var errNotImplemented = errors.New("not implemented")
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
|
||||
}
|
||||
|
||||
func isSystemdResolvedRunning() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isSystemdResolveConfMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -242,25 +242,3 @@ func getSystemdDbusProperty(property string, store any) error {
|
||||
|
||||
return v.Store(store)
|
||||
}
|
||||
|
||||
func isSystemdResolvedRunning() bool {
|
||||
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
|
||||
}
|
||||
|
||||
func isSystemdResolveConfMode() bool {
|
||||
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
return false
|
||||
}
|
||||
|
||||
var value string
|
||||
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
|
||||
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if value == systemdDbusResolvConfModeForeign {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
@@ -14,6 +14,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||
)
|
||||
|
||||
func CheckUncleanShutdown(wgIface string) error {
|
||||
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
@@ -78,11 +78,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}()
|
||||
|
||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||
if r.Extra == nil {
|
||||
r.SetEdns0(4096, false)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
|
||||
@@ -2,17 +2,12 @@
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -29,21 +28,17 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@@ -63,7 +58,8 @@ type EngineConfig struct {
|
||||
WgIfaceName string
|
||||
|
||||
// WgAddr is a Wireguard local address (Netbird Network IP)
|
||||
WgAddr string
|
||||
WgAddr string
|
||||
WgAddr6 string
|
||||
|
||||
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
|
||||
WgPrivateKey wgtypes.Key
|
||||
@@ -94,8 +90,6 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -107,8 +101,8 @@ type Engine struct {
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
peerConns map[string]*peer.Conn
|
||||
|
||||
beforePeerHook nbnet.AddHookFunc
|
||||
afterPeerHook nbnet.RemoveHookFunc
|
||||
beforePeerHook peer.BeforeAddPeerHookFunc
|
||||
afterPeerHook peer.AfterRemovePeerHookFunc
|
||||
|
||||
// rpManager is a Rosenpass manager
|
||||
rpManager *rosenpass.Manager
|
||||
@@ -161,9 +155,6 @@ type Engine struct {
|
||||
wgProbe *Probe
|
||||
|
||||
wgConnWorker sync.WaitGroup
|
||||
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -181,7 +172,6 @@ func NewEngine(
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
return NewEngineWithProbes(
|
||||
clientCtx,
|
||||
@@ -195,7 +185,6 @@ func NewEngine(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
checks,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -212,7 +201,6 @@ func NewEngineWithProbes(
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
|
||||
return &Engine{
|
||||
@@ -233,7 +221,6 @@ func NewEngineWithProbes(
|
||||
signalProbe: signalProbe,
|
||||
relayProbe: relayProbe,
|
||||
wgProbe: wgProbe,
|
||||
checks: checks,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,6 +269,8 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
||||
@@ -289,9 +278,6 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
|
||||
userspace := e.wgInterface.IsUserspaceBind()
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
if e.config.RosenpassPermissive {
|
||||
@@ -316,7 +302,7 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
@@ -542,10 +528,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if update.GetNetworkMap() != nil {
|
||||
// only apply new changes and ignore old ones
|
||||
err := e.updateNetworkMap(update.GetNetworkMap())
|
||||
@@ -553,27 +535,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
||||
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
// if checks are equal, we skip the update
|
||||
if isChecksEqual(e.checks, checks) {
|
||||
return nil
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
log.Errorf("could not sync meta: error %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -589,8 +551,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
} else {
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on Windows is not supported")
|
||||
return nil
|
||||
}
|
||||
// start SSH server if it wasn't running
|
||||
@@ -642,6 +604,32 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
log.Infof("updated peer address from %s to %s", oldAddr, conf.Address)
|
||||
}
|
||||
|
||||
if e.wgInterface.Address6() == nil && conf.Address6 != "" ||
|
||||
e.wgInterface.Address6() != nil && e.wgInterface.Address6().String() != conf.Address6 {
|
||||
oldAddr := "none"
|
||||
if e.wgInterface.Address6() != nil {
|
||||
oldAddr = e.wgInterface.Address6().String()
|
||||
}
|
||||
newAddr := "none"
|
||||
if conf.Address6 != "" {
|
||||
newAddr = conf.Address6
|
||||
}
|
||||
log.Debugf("updating peer IPv6 address from %s to %s", oldAddr, newAddr)
|
||||
err := e.wgInterface.UpdateAddr6(conf.Address6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.config.WgAddr6 = conf.Address6
|
||||
|
||||
err = e.acl.ResetV6Acl()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.routeManager.ResetV6Routes()
|
||||
log.Infof("updated peer IPv6 address from %s to %s", oldAddr, conf.Address6)
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
err := e.updateSSH(conf.GetSshConfig())
|
||||
if err != nil {
|
||||
@@ -651,6 +639,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
|
||||
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
|
||||
IP: e.config.WgAddr,
|
||||
IP6: e.config.WgAddr6,
|
||||
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||
FQDN: conf.GetFqdn(),
|
||||
@@ -663,14 +652,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
go func() {
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
|
||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
err := e.mgmClient.Sync(e.ctx, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
@@ -737,20 +719,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = clientRoutes
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
@@ -792,6 +760,19 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = clientRoutes
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
@@ -819,24 +800,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, protoRoute := range protoRoutes {
|
||||
var prefix netip.Prefix
|
||||
if len(protoRoute.Domains) == 0 {
|
||||
var err error
|
||||
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
|
||||
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
|
||||
convertedRoute := &route.Route{
|
||||
ID: route.ID(protoRoute.ID),
|
||||
Network: prefix,
|
||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||
NetID: route.NetID(protoRoute.NetID),
|
||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||
Peer: protoRoute.Peer,
|
||||
Metric: int(protoRoute.Metric),
|
||||
Masquerade: protoRoute.Masquerade,
|
||||
KeepRoute: protoRoute.KeepRoute,
|
||||
}
|
||||
routes = append(routes, convertedRoute)
|
||||
}
|
||||
@@ -1039,6 +1011,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
||||
RosenpassAddr: e.getRosenpassAddr(),
|
||||
}
|
||||
@@ -1101,6 +1074,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
|
||||
|
||||
var rosenpassPubKey []byte
|
||||
rosenpassAddr := ""
|
||||
if msg.GetBody().GetRosenpassConfig() != nil {
|
||||
@@ -1123,6 +1098,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.RegisterProtoSupportMeta(msg.GetBody().GetFeaturesSupported())
|
||||
|
||||
var rosenpassPubKey []byte
|
||||
rosenpassAddr := ""
|
||||
if msg.GetBody().GetRosenpassConfig() != nil {
|
||||
@@ -1260,8 +1237,7 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
info := system.GetInfo(e.ctx)
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
netMap, err := e.mgmClient.GetNetworkMap()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -1290,7 +1266,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
default:
|
||||
}
|
||||
|
||||
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
|
||||
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgAddr6, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
|
||||
}
|
||||
|
||||
func (e *Engine) wgInterfaceCreate() (err error) {
|
||||
@@ -1465,15 +1441,6 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
|
||||
}
|
||||
|
||||
func (e *Engine) restartEngine() {
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) startNetworkMonitor() {
|
||||
if !e.config.NetworkMonitor {
|
||||
log.Infof("Network monitor is disabled, not starting")
|
||||
@@ -1482,54 +1449,17 @@ func (e *Engine) startNetworkMonitor() {
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
go func() {
|
||||
var mu sync.Mutex
|
||||
var debounceTimer *time.Timer
|
||||
|
||||
// Start the network monitor with a callback, Start will block until the monitor is stopped,
|
||||
// a network change is detected, or an error occurs on start up
|
||||
err := e.networkMonitor.Start(e.ctx, func() {
|
||||
// This function is called when a network change is detected
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
}
|
||||
|
||||
// Set a new timer to debounce rapid network changes
|
||||
debounceTimer = time.AfterFunc(1*time.Second, func() {
|
||||
// This function is called after the debounce period
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
})
|
||||
})
|
||||
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
|
||||
log.Errorf("Network monitor: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
var vpnRoutes []netip.Prefix
|
||||
for _, routes := range e.GetClientRoutes() {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
|
||||
return true, prefix, nil
|
||||
}
|
||||
|
||||
return false, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
return slices.Equal(checks.Files, oChecks.Files)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
@@ -58,9 +57,9 @@ var (
|
||||
)
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping TestEngine_SSH")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping TestEngine_SSH on Windows")
|
||||
}
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
@@ -78,7 +77,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -174,7 +173,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
//time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
||||
|
||||
@@ -212,16 +211,16 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", "", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
@@ -394,7 +393,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
@@ -409,7 +408,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -566,15 +565,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgAddr6: "",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgAddr6, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
input := struct {
|
||||
inputSerial uint64
|
||||
@@ -736,16 +736,17 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgAddr6: "",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgAddr6, 33100, key.String(), iface.DefaultMTU, newNet, nil)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
@@ -811,13 +812,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
sigServer, signalAddr, err := startSignal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, dir)
|
||||
mgmtServer, mgmtAddr, err := startManagement(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1009,14 +1010,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
WgPort: wgPort,
|
||||
}
|
||||
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
func startSignal() (*grpc.Server, string, error) {
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
@@ -1024,9 +1023,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
@@ -1037,9 +1034,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
func startManagement(dataDir string) (*grpc.Server, string, error) {
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
@@ -1056,25 +1051,23 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
store, _, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ package networkmonitor
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
@@ -12,10 +14,10 @@ import (
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||
@@ -45,6 +47,24 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
|
||||
// handle interface state changes
|
||||
case unix.RTM_IFINFO:
|
||||
ifinfo, err := parseInterfaceMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("Network monitor: error parsing interface message: %v", err)
|
||||
continue
|
||||
}
|
||||
if msg.Flags&unix.IFF_UP != 0 {
|
||||
continue
|
||||
}
|
||||
if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
|
||||
go callback()
|
||||
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
@@ -66,7 +86,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
}
|
||||
@@ -76,7 +96,25 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.InterfaceMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
@@ -91,5 +129,5 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
return routemanager.MsgToRoute(msg)
|
||||
}
|
||||
|
||||
@@ -6,13 +6,14 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
)
|
||||
|
||||
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
|
||||
@@ -28,22 +29,23 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
|
||||
nw.wg.Add(1)
|
||||
defer nw.wg.Done()
|
||||
|
||||
var nexthop4, nexthop6 systemops.Nexthop
|
||||
var nexthop4, nexthop6 netip.Addr
|
||||
var intf4, intf6 *net.Interface
|
||||
|
||||
operation := func() error {
|
||||
var errv4, errv6 error
|
||||
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified())
|
||||
|
||||
if errv4 != nil && errv6 != nil {
|
||||
return errors.New("failed to get default next hops")
|
||||
}
|
||||
|
||||
if errv4 == nil {
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name)
|
||||
}
|
||||
if errv6 == nil {
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name)
|
||||
}
|
||||
|
||||
// continue if either route was found
|
||||
@@ -63,7 +65,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
|
||||
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil {
|
||||
return fmt.Errorf("check change: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,22 +6,27 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
|
||||
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||
if intfv4 == nil && intfv6 == nil {
|
||||
return errors.New("no interfaces available")
|
||||
}
|
||||
|
||||
linkChan := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
||||
return fmt.Errorf("subscribe to link updates: %v", err)
|
||||
}
|
||||
|
||||
routeChan := make(chan netlink.RouteUpdate)
|
||||
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
|
||||
return fmt.Errorf("subscribe to route updates: %v", err)
|
||||
@@ -33,6 +38,25 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
|
||||
// handle interface state changes
|
||||
case update := <-linkChan:
|
||||
if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch update.Header.Type {
|
||||
case syscall.RTM_DELLINK:
|
||||
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_NEWLINK:
|
||||
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
|
||||
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handle route changes
|
||||
case route := <-routeChan:
|
||||
// default route and main table
|
||||
@@ -46,7 +70,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_DELROUTE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
|
||||
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
go callback()
|
||||
return nil
|
||||
|
||||
@@ -5,12 +5,11 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -26,16 +25,20 @@ const (
|
||||
|
||||
const interval = 10 * time.Second
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
var neighborv4, neighborv6 *systemops.Neighbor
|
||||
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||
var neighborv4, neighborv6 *routemanager.Neighbor
|
||||
{
|
||||
initialNeighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
|
||||
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
|
||||
if n, ok := initialNeighbors[nexthopv4]; ok {
|
||||
neighborv4 = &n
|
||||
}
|
||||
if n, ok := initialNeighbors[nexthopv6]; ok {
|
||||
neighborv6 = &n
|
||||
}
|
||||
}
|
||||
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
|
||||
|
||||
@@ -47,7 +50,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
case <-ticker.C:
|
||||
if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
|
||||
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
@@ -55,21 +58,13 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
}
|
||||
}
|
||||
|
||||
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
|
||||
if n, ok := initialNeighbors[nexthop.IP]; ok &&
|
||||
n.State != unreachable &&
|
||||
n.State != incomplete &&
|
||||
n.State != tbd {
|
||||
return &n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func changed(
|
||||
nexthopv4 systemops.Nexthop,
|
||||
neighborv4 *systemops.Neighbor,
|
||||
nexthopv6 systemops.Nexthop,
|
||||
neighborv6 *systemops.Neighbor,
|
||||
nexthopv4 netip.Addr,
|
||||
intfv4 *net.Interface,
|
||||
neighborv4 *routemanager.Neighbor,
|
||||
nexthopv6 netip.Addr,
|
||||
intfv6 *net.Interface,
|
||||
neighborv6 *routemanager.Neighbor,
|
||||
) bool {
|
||||
neighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
@@ -86,7 +81,7 @@ func changed(
|
||||
return false
|
||||
}
|
||||
|
||||
if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
|
||||
if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -94,74 +89,44 @@ func changed(
|
||||
}
|
||||
|
||||
// routeChanged checks if the default routes still point to our nexthop/interface
|
||||
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
|
||||
if !nexthop.IP.IsValid() {
|
||||
func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool {
|
||||
if !nexthop.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
unspec := getUnspecifiedPrefix(nexthop.IP)
|
||||
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
|
||||
var unspec netip.Prefix
|
||||
if nexthop.Is6() {
|
||||
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
} else {
|
||||
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
|
||||
|
||||
if !foundMatchingRoute {
|
||||
logRouteChange(nexthop.IP, intf)
|
||||
if r, ok := routes[unspec]; ok {
|
||||
if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 {
|
||||
intf := "<nil>"
|
||||
if r.Interface != nil {
|
||||
intf = r.Interface.Name
|
||||
}
|
||||
log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf)
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
log.Infof("network monitor: default route is gone")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
|
||||
if ip.Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
|
||||
var defaultRoutes []string
|
||||
foundMatchingRoute := false
|
||||
|
||||
for _, r := range routes {
|
||||
if r.Destination == unspec {
|
||||
routeInfo := formatRouteInfo(r)
|
||||
defaultRoutes = append(defaultRoutes, routeInfo)
|
||||
|
||||
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
|
||||
foundMatchingRoute = true
|
||||
log.Debugf("network monitor: found matching default route: %s", routeInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultRoutes, foundMatchingRoute
|
||||
}
|
||||
|
||||
func formatRouteInfo(r systemops.Route) string {
|
||||
newIntf := "<nil>"
|
||||
if r.Interface != nil {
|
||||
newIntf = r.Interface.Name
|
||||
}
|
||||
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
|
||||
}
|
||||
|
||||
func logRouteChange(ip netip.Addr, intf *net.Interface) {
|
||||
oldIntf := "<nil>"
|
||||
if intf != nil {
|
||||
oldIntf = intf.Name
|
||||
}
|
||||
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
|
||||
}
|
||||
|
||||
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
|
||||
func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool {
|
||||
if neighbor == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
|
||||
if n, ok := neighbors[nexthop.IP]; ok {
|
||||
if n.State == unreachable || n.State == incomplete {
|
||||
if n, ok := neighbors[nexthop]; ok {
|
||||
if n.State != reachable && n.State != permanent {
|
||||
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
|
||||
return true
|
||||
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
|
||||
@@ -185,13 +150,13 @@ func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, ne
|
||||
return false
|
||||
}
|
||||
|
||||
func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
|
||||
entries, err := systemops.GetNeighbors()
|
||||
func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
|
||||
entries, err := routemanager.GetNeighbors()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
|
||||
neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries))
|
||||
for _, entry := range entries {
|
||||
neighbours[entry.IPAddress] = entry
|
||||
}
|
||||
@@ -199,13 +164,18 @@ func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
|
||||
return neighbours, nil
|
||||
}
|
||||
|
||||
func getRoutes() ([]systemops.Route, error) {
|
||||
entries, err := systemops.GetRoutes()
|
||||
func getRoutes() (map[netip.Prefix]routemanager.Route, error) {
|
||||
entries, err := routemanager.GetRoutes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
routes := make(map[netip.Prefix]routemanager.Route, len(entries))
|
||||
for _, entry := range entries {
|
||||
routes[entry.Destination] = entry
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func stateFromInt(state uint8) string {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -68,6 +70,9 @@ type ConnConfig struct {
|
||||
|
||||
NATExternalIPs []string
|
||||
|
||||
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
|
||||
UserspaceBind bool
|
||||
|
||||
// RosenpassPubKey is this peer's Rosenpass public key
|
||||
RosenpassPubKey []byte
|
||||
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
|
||||
@@ -98,6 +103,9 @@ type IceCredentials struct {
|
||||
Pwd string
|
||||
}
|
||||
|
||||
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
|
||||
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
|
||||
|
||||
type Conn struct {
|
||||
config ConnConfig
|
||||
mu sync.Mutex
|
||||
@@ -127,13 +135,30 @@ type Conn struct {
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
wgProxy wgproxy.Proxy
|
||||
|
||||
remoteModeCh chan ModeMessage
|
||||
meta meta
|
||||
|
||||
adapter iface.TunAdapter
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
sentExtraSrflx bool
|
||||
|
||||
remoteEndpoint *net.UDPAddr
|
||||
remoteConn *ice.Conn
|
||||
|
||||
connID nbnet.ConnectionID
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
beforeAddPeerHooks []BeforeAddPeerHookFunc
|
||||
afterRemovePeerHooks []AfterRemovePeerHookFunc
|
||||
}
|
||||
|
||||
// meta holds meta information about a connection
|
||||
type meta struct {
|
||||
protoSupport signal.FeaturesSupport
|
||||
}
|
||||
|
||||
// ModeMessage represents a connection mode chosen by the peer
|
||||
type ModeMessage struct {
|
||||
// Direct indicates that it decided to use a direct connection
|
||||
Direct bool
|
||||
}
|
||||
|
||||
// GetConf returns the connection config
|
||||
@@ -162,6 +187,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
statusRecorder: statusRecorder,
|
||||
remoteModeCh: make(chan ModeMessage, 1),
|
||||
wgProxyFactory: wgProxyFactory,
|
||||
adapter: adapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
@@ -352,6 +378,8 @@ func (conn *Conn) Open(ctx context.Context) error {
|
||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||
}
|
||||
|
||||
conn.remoteConn = remoteConn
|
||||
|
||||
// the ice connection has been established successfully so we are ready to start the proxy
|
||||
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
|
||||
remoteOfferAnswer.RosenpassAddr)
|
||||
@@ -376,11 +404,11 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
|
||||
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
|
||||
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
||||
}
|
||||
|
||||
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
|
||||
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
|
||||
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
||||
}
|
||||
|
||||
@@ -409,6 +437,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
}
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.remoteEndpoint = endpointUdpAddr
|
||||
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||
|
||||
conn.connID = nbnet.GenerateConnID()
|
||||
@@ -455,7 +484,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
log.Warnf("unable to save peer's state, got error: %v", err)
|
||||
}
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(conn.config.WgConfig.AllowedIps)
|
||||
_, ipNet, err := net.ParseCIDR(strings.Split(conn.config.WgConfig.AllowedIps, ",")[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -594,39 +623,40 @@ func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) err
|
||||
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
|
||||
// and then signals them to the remote peer
|
||||
func (conn *Conn) onICECandidate(candidate ice.Candidate) {
|
||||
// nil means candidate gathering has been ended
|
||||
if candidate == nil {
|
||||
return
|
||||
if candidate != nil {
|
||||
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
|
||||
log.Debugf("discovered local candidate %s", candidate.String())
|
||||
go func() {
|
||||
err := conn.signalCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
|
||||
}
|
||||
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
||||
relatedAdd := candidate.RelatedAddress()
|
||||
extraSrflx, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: candidate.NetworkType().String(),
|
||||
Address: candidate.Address(),
|
||||
Port: relatedAdd.Port,
|
||||
Component: candidate.Component(),
|
||||
RelAddr: relatedAdd.Address,
|
||||
RelPort: relatedAdd.Port,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
err = conn.signalCandidate(extraSrflx)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
|
||||
return
|
||||
}
|
||||
conn.sentExtraSrflx = true
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
|
||||
log.Debugf("discovered local candidate %s", candidate.String())
|
||||
go func() {
|
||||
err := conn.signalCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if !conn.shouldSendExtraSrflxCandidate(candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
extraSrflx, err := extraSrflxCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
conn.sentExtraSrflx = true
|
||||
|
||||
go func() {
|
||||
err = conn.signalCandidate(extraSrflx)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
|
||||
@@ -761,6 +791,10 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
|
||||
return
|
||||
}
|
||||
|
||||
if candidateViaRoutes(candidate, haRoutes) {
|
||||
return
|
||||
}
|
||||
|
||||
err := conn.agent.AddRemoteCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
|
||||
@@ -773,21 +807,36 @@ func (conn *Conn) GetKey() string {
|
||||
return conn.config.Key
|
||||
}
|
||||
|
||||
func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
|
||||
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
||||
return true
|
||||
// RegisterProtoSupportMeta register supported proto message in the connection metadata
|
||||
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
|
||||
protoSupport := signal.ParseFeaturesSupported(support)
|
||||
conn.meta.protoSupport = protoSupport
|
||||
}
|
||||
|
||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||
var routePrefixes []netip.Prefix
|
||||
for _, routes := range clientRoutes {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
routePrefixes = append(routePrefixes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, prefix := range routePrefixes {
|
||||
// default route is
|
||||
if prefix.Bits() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefix.Contains(addr) {
|
||||
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
|
||||
relatedAdd := candidate.RelatedAddress()
|
||||
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: candidate.NetworkType().String(),
|
||||
Address: candidate.Address(),
|
||||
Port: relatedAdd.Port,
|
||||
Component: candidate.Component(),
|
||||
RelAddr: relatedAdd.Address,
|
||||
RelPort: relatedAdd.Port,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -51,7 +51,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -88,7 +88,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -124,7 +124,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -154,7 +154,7 @@ func TestConn_Status(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
|
||||
@@ -2,17 +2,14 @@ package peer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// State contains the latest state of a peer
|
||||
@@ -40,25 +37,25 @@ type State struct {
|
||||
// AddRoute add a single route to routes map
|
||||
func (s *State) AddRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
if s.routes == nil {
|
||||
s.routes = make(map[string]struct{})
|
||||
}
|
||||
s.routes[network] = struct{}{}
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// SetRoutes set state routes
|
||||
func (s *State) SetRoutes(routes map[string]struct{}) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
s.routes = routes
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// DeleteRoute removes a route from the network amp
|
||||
func (s *State) DeleteRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
delete(s.routes, network)
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// GetRoutes return routes map
|
||||
@@ -71,6 +68,7 @@ func (s *State) GetRoutes() map[string]struct{} {
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
type LocalPeerState struct {
|
||||
IP string
|
||||
IP6 string
|
||||
PubKey string
|
||||
KernelInterface bool
|
||||
FQDN string
|
||||
@@ -120,23 +118,22 @@ type FullStatus struct {
|
||||
|
||||
// Status holds a state of peers, signal, management connections and relays
|
||||
type Status struct {
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
resolvedDomainsStates map[domain.Domain][]netip.Prefix
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
|
||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||
@@ -147,12 +144,11 @@ type Status struct {
|
||||
// NewRecorder returns a new Status instance
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,7 +189,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return State{}, iface.ErrPeerNotFound
|
||||
return State{}, errors.New("peer not found")
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
@@ -434,18 +430,6 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||
d.nsGroupStates = dnsStates
|
||||
}
|
||||
|
||||
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.resolvedDomainsStates[domain] = prefixes
|
||||
}
|
||||
|
||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
delete(d.resolvedDomainsStates, domain)
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
return RosenpassState{
|
||||
d.rosenpassEnabled,
|
||||
@@ -510,12 +494,6 @@ func (d *Status) GetDNSStates() []NSGroupState {
|
||||
return d.nsGroupStates
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
}
|
||||
|
||||
// GetFullStatus gets full status
|
||||
func (d *Status) GetFullStatus() FullStatus {
|
||||
d.mux.Lock()
|
||||
|
||||
@@ -3,20 +3,19 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const minRangeBits = 7
|
||||
|
||||
type routerPeerStatus struct {
|
||||
connected bool
|
||||
relayed bool
|
||||
@@ -29,42 +28,34 @@ type routesUpdate struct {
|
||||
routes []*route.Route
|
||||
}
|
||||
|
||||
// RouteHandler defines the interface for handling routes
|
||||
type RouteHandler interface {
|
||||
String() string
|
||||
AddRoute(ctx context.Context) error
|
||||
RemoveRoute() error
|
||||
AddAllowedIPs(peerKey string) error
|
||||
RemoveAllowedIPs() error
|
||||
}
|
||||
|
||||
type clientNetwork struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
stop context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
routes map[route.ID]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
peerStateUpdate chan struct{}
|
||||
routePeersNotifiers map[string]chan struct{}
|
||||
currentChosen *route.Route
|
||||
handler RouteHandler
|
||||
chosenRoute *route.Route
|
||||
chosenIP *net.IP
|
||||
network netip.Prefix
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
client := &clientNetwork{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
stop: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
|
||||
network: network,
|
||||
}
|
||||
return client
|
||||
}
|
||||
@@ -96,8 +87,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||
// * Metric: Routes with lower metrics (better) are prioritized.
|
||||
// * Non-relayed: Routes without relays are preferred.
|
||||
// * Direct connections: Routes with direct peer connections are favored.
|
||||
// * Latency: Routes with lower latency are prioritized.
|
||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||
// * Latency: Routes with lower latency are prioritized.
|
||||
//
|
||||
// It returns the ID of the selected optimal route.
|
||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
||||
@@ -106,8 +97,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
currScore := float64(0)
|
||||
|
||||
currID := route.ID("")
|
||||
if c.currentChosen != nil {
|
||||
currID = c.currentChosen.ID
|
||||
if c.chosenRoute != nil {
|
||||
currID = c.chosenRoute.ID
|
||||
}
|
||||
|
||||
for _, r := range c.routes {
|
||||
@@ -161,18 +152,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
peers = append(peers, r.Peer)
|
||||
}
|
||||
|
||||
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
|
||||
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
||||
case chosen != currID:
|
||||
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
|
||||
if currScore != 0 && currScore+0.01 > chosenScore {
|
||||
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
||||
log.Debugf("keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
||||
return currID
|
||||
}
|
||||
var p string
|
||||
if rt := c.routes[chosen]; rt != nil {
|
||||
p = rt.Peer
|
||||
}
|
||||
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
|
||||
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, p, chosenScore, c.network)
|
||||
}
|
||||
|
||||
return chosen
|
||||
@@ -206,103 +197,109 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
||||
c.removeStateRoute()
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get peer state: %v", err)
|
||||
}
|
||||
|
||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
state.DeleteRoute(c.network.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
|
||||
if state.ConnStatus != peer.StatusConnected {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
if c.currentChosen == nil {
|
||||
return nil
|
||||
}
|
||||
if c.chosenRoute != nil {
|
||||
// TODO IPv6 (pass wgInterface)
|
||||
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||
return fmt.Errorf("remove route: %v", err)
|
||||
}
|
||||
}
|
||||
if err := c.handler.RemoveRoute(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||
|
||||
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
|
||||
// If no route is chosen, remove the route from the peer and system
|
||||
if newChosenID == "" {
|
||||
if chosen == "" {
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
|
||||
return fmt.Errorf("remove route from peer and system: %v", err)
|
||||
}
|
||||
|
||||
c.currentChosen = nil
|
||||
c.chosenRoute = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the chosen route is the same as the current route, do nothing
|
||||
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
|
||||
c.currentChosen.IsEqual(c.routes[newChosenID]) {
|
||||
return nil
|
||||
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
||||
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if c.currentChosen == nil {
|
||||
// If they were not previously assigned to another peer, add routes to the system first
|
||||
if err := c.handler.AddRoute(c.ctx); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
if c.chosenRoute != nil {
|
||||
// If a previous route exists, remove it from the peer
|
||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||
return fmt.Errorf("remove route from peer: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Otherwise, remove the allowed IPs from the previous peer first
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
// TODO recheck IPv6
|
||||
gwAddr := c.wgInterface.Address().IP
|
||||
c.chosenIP = &gwAddr
|
||||
if c.network.Addr().Is6() {
|
||||
if c.wgInterface.Address6() == nil {
|
||||
return fmt.Errorf("Could not assign IPv6 route %s for peer %s because no IPv6 address is assigned",
|
||||
c.network.String(), c.wgInterface.Address().IP.String())
|
||||
}
|
||||
c.chosenIP = &c.wgInterface.Address6().IP
|
||||
}
|
||||
// otherwise add the route to the system
|
||||
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||
c.network.String(), c.chosenIP.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
c.currentChosen = c.routes[newChosenID]
|
||||
c.chosenRoute = c.routes[chosen]
|
||||
|
||||
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
} else {
|
||||
state.AddRoute(c.network.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.addStateRoute()
|
||||
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
|
||||
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) addStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.AddRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.DeleteRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
@@ -333,23 +330,24 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
log.Debugf("Stopping watcher for network [%v]", c.handler)
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
|
||||
log.Debugf("stopping watcher for network %s", c.network)
|
||||
err := c.removeRouteFromPeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
|
||||
}
|
||||
return
|
||||
case <-c.peerStateUpdate:
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
|
||||
}
|
||||
case update := <-c.routeUpdate:
|
||||
if update.updateSerial < c.updateSerial {
|
||||
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
|
||||
log.Warnf("Received a routes update with smaller serial number, ignoring it")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("Received a new client network route update for [%v]", c.handler)
|
||||
log.Debugf("Received a new client network route update for %s", c.network)
|
||||
|
||||
c.handleUpdate(update)
|
||||
|
||||
@@ -357,7 +355,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
|
||||
}
|
||||
|
||||
c.startPeersStatusChangeWatcher()
|
||||
@@ -365,9 +363,14 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
|
||||
if rt.IsDynamic() {
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
|
||||
func (c *clientNetwork) getAsInterface() *net.Interface {
|
||||
intf, err := net.InterfaceByName(c.wgInterface.Name())
|
||||
if err != nil {
|
||||
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
|
||||
intf = &net.Interface{
|
||||
Name: c.wgInterface.Name(),
|
||||
}
|
||||
}
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -341,9 +340,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
|
||||
// create new clientNetwork
|
||||
client := &clientNetwork{
|
||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
||||
routes: tc.existingRoutes,
|
||||
currentChosen: currentRoute,
|
||||
network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
routes: tc.existingRoutes,
|
||||
chosenRoute: currentRoute,
|
||||
}
|
||||
|
||||
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
||||
|
||||
@@ -1,378 +0,0 @@
|
||||
package dynamic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInterval = time.Minute
|
||||
|
||||
minInterval = 2 * time.Second
|
||||
failureInterval = 5 * time.Second
|
||||
|
||||
addAllowedIP = "add allowed IP %s: %w"
|
||||
)
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type resolveResult struct {
|
||||
domain domain.Domain
|
||||
prefix netip.Prefix
|
||||
err error
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
interval time.Duration
|
||||
dynamicDomains domainMap
|
||||
mu sync.Mutex
|
||||
currentPeerKey string
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func NewRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
interval: interval,
|
||||
dynamicDomains: domainMap{},
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) String() string {
|
||||
s, err := r.route.Domains.String()
|
||||
if err != nil {
|
||||
return r.route.Domains.PunycodeString()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(ctx context.Context) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
ctx, r.cancel = context.WithCancel(ctx)
|
||||
|
||||
go r.startResolver(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoute will stop the dynamic resolver and remove all dynamic routes.
|
||||
// It doesn't touch allowed IPs, these should be removed separately and before calling this method.
|
||||
func (r *Route) RemoveRoute() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range r.dynamicDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
r.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
|
||||
r.dynamicDomains = domainMap{}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, domainPrefixes := range r.dynamicDomains {
|
||||
for _, prefix := range domainPrefixes {
|
||||
if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
r.currentPeerKey = peerKey
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) RemoveAllowedIPs() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, domainPrefixes := range r.dynamicDomains {
|
||||
for _, prefix := range domainPrefixes {
|
||||
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.currentPeerKey = ""
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) startResolver(ctx context.Context) {
|
||||
log.Debugf("Starting dynamic route resolver for domains [%v]", r)
|
||||
|
||||
interval := r.interval
|
||||
if interval < minInterval {
|
||||
interval = minInterval
|
||||
log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
if err := r.update(ctx); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
if interval > failureInterval {
|
||||
ticker.Reset(failureInterval)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.update(ctx); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
// Use a lower ticker interval if the update fails
|
||||
if interval > failureInterval {
|
||||
ticker.Reset(failureInterval)
|
||||
}
|
||||
} else if interval > failureInterval {
|
||||
// Reset to the original interval if the update succeeds
|
||||
ticker.Reset(interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) update(ctx context.Context) error {
|
||||
if resolved, err := r.resolveDomains(); err != nil {
|
||||
return fmt.Errorf("resolve domains: %w", err)
|
||||
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
|
||||
return fmt.Errorf("update dynamic routes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) resolveDomains() (domainMap, error) {
|
||||
results := make(chan resolveResult)
|
||||
go r.resolve(results)
|
||||
|
||||
resolved := domainMap{}
|
||||
var merr *multierror.Error
|
||||
|
||||
for result := range results {
|
||||
if result.err != nil {
|
||||
merr = multierror.Append(merr, result.err)
|
||||
} else {
|
||||
resolved[result.domain] = append(resolved[result.domain], result.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return resolved, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) resolve(results chan resolveResult) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, d := range r.route.Domains {
|
||||
wg.Add(1)
|
||||
go func(domain domain.Domain) {
|
||||
defer wg.Done()
|
||||
ips, err := net.LookupIP(string(domain))
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||
return
|
||||
}
|
||||
for _, ip := range ips {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)}
|
||||
return
|
||||
}
|
||||
results <- resolveResult{domain: domain, prefix: prefix}
|
||||
}
|
||||
}(d)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}
|
||||
|
||||
func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
for domain, newPrefixes := range newDomains {
|
||||
oldPrefixes := r.dynamicDomains[domain]
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
addedPrefixes, err := r.addRoutes(domain, toAdd)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
} else if len(addedPrefixes) > 0 {
|
||||
log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", "))
|
||||
}
|
||||
|
||||
removedPrefixes, err := r.removeRoutes(toRemove)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
} else if len(removedPrefixes) > 0 {
|
||||
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", "))
|
||||
}
|
||||
|
||||
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
||||
r.dynamicDomains[domain] = updatedPrefixes
|
||||
|
||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) {
|
||||
var addedPrefixes []netip.Prefix
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
|
||||
continue
|
||||
}
|
||||
if r.currentPeerKey != "" {
|
||||
if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
|
||||
}
|
||||
}
|
||||
addedPrefixes = append(addedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return addedPrefixes, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
|
||||
if r.route.KeepRoute {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var removedPrefixes []netip.Prefix
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
|
||||
}
|
||||
if r.currentPeerKey != "" {
|
||||
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
removedPrefixes = append(removedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return removedPrefixes, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error {
|
||||
if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
return fmt.Errorf(addAllowedIP, prefix, err)
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = false
|
||||
}
|
||||
for _, prefix := range newPrefixes {
|
||||
if _, exists := prefixSet[prefix]; exists {
|
||||
prefixSet[prefix] = true
|
||||
} else {
|
||||
toAdd = append(toAdd, prefix)
|
||||
}
|
||||
}
|
||||
for prefix, inUse := range prefixSet {
|
||||
if !inUse {
|
||||
toRemove = append(toRemove, prefix)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix {
|
||||
prefixSet := make(map[netip.Prefix]struct{})
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = struct{}{}
|
||||
}
|
||||
for _, prefix := range removedPrefixes {
|
||||
delete(prefixSet, prefix)
|
||||
}
|
||||
for _, prefix := range addedPrefixes {
|
||||
prefixSet[prefix] = struct{}{}
|
||||
}
|
||||
|
||||
var combinedPrefixes []netip.Prefix
|
||||
for prefix := range prefixSet {
|
||||
combinedPrefixes = append(combinedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return combinedPrefixes
|
||||
}
|
||||
@@ -2,23 +2,18 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -26,85 +21,51 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
|
||||
// nolint:unused
|
||||
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
ResetV6Routes()
|
||||
EnableServerRouter(firewall firewall.Manager) error
|
||||
Stop()
|
||||
}
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
sysOps *systemops.SysOps
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
dnsRouteInterval time.Duration
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier
|
||||
}
|
||||
|
||||
func NewManager(
|
||||
ctx context.Context,
|
||||
pubKey string,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface *iface.WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
initialRoutes []*route.Route,
|
||||
) *DefaultManager {
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
sysOps := systemops.NewSysOps(wgInterface)
|
||||
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
dnsRouteInterval: dnsRouteInterval,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
sysOps: sysOps,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: newNotifier(),
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: newNotifier(),
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, _ any) (any, error) {
|
||||
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
func(prefix netip.Prefix, _ any) error {
|
||||
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
)
|
||||
|
||||
dm.allowedIPsRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, peerKey string) (string, error) {
|
||||
// save peerKey to use it in the remove function
|
||||
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
|
||||
},
|
||||
func(prefix netip.Prefix, peerKey string) error {
|
||||
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
|
||||
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
|
||||
return err
|
||||
}
|
||||
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
cr := dm.clientRoutes(initialRoutes)
|
||||
dm.notifier.setInitialClientRoutes(cr)
|
||||
@@ -113,12 +74,12 @@ func NewManager(
|
||||
}
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
if nbnet.CustomRoutingDisabled() {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
|
||||
@@ -126,7 +87,7 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
signalAddress := m.statusRecorder.GetSignalState().URL
|
||||
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
|
||||
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
|
||||
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
@@ -150,19 +111,8 @@ func (m *DefaultManager) Stop() {
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
|
||||
if m.routeRefCounter != nil {
|
||||
if err := m.routeRefCounter.Flush(); err != nil {
|
||||
log.Errorf("Error flushing route ref counter: %v", err)
|
||||
}
|
||||
}
|
||||
if m.allowedIPsRefCounter != nil {
|
||||
if err := m.allowedIPsRefCounter.Flush(); err != nil {
|
||||
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
@@ -199,6 +149,18 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
}
|
||||
}
|
||||
|
||||
// ResetV6Routes deletes all IPv6 routes (necessary if IPv6 address changes).
|
||||
// It is expected that UpdateRoute is called afterwards to recreate the routing table.
|
||||
func (m *DefaultManager) ResetV6Routes() {
|
||||
for id, client := range m.clientNetworks {
|
||||
if client.network.Addr().Is6() {
|
||||
log.Debugf("stopping client network watcher due to IPv6 address change, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetRouteChangeListener set RouteListener for route change notifier
|
||||
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
|
||||
m.notifier.setListener(listener)
|
||||
@@ -236,7 +198,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
@@ -248,7 +210,7 @@ func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
||||
for id, client := range m.clientNetworks {
|
||||
if _, ok := networks[id]; !ok {
|
||||
log.Debugf("Stopping client network watcher, %s", id)
|
||||
client.cancel()
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
@@ -261,7 +223,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
@@ -279,7 +241,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
ownNetworkIDs := make(map[route.HAUniqueID]bool)
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
haID := newRoute.GetHAUniqueID()
|
||||
haID := route.GetHAUniqueID(newRoute)
|
||||
if newRoute.Peer == m.pubKey {
|
||||
ownNetworkIDs[haID] = true
|
||||
// only linux is supported for now
|
||||
@@ -292,9 +254,9 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
haID := newRoute.GetHAUniqueID()
|
||||
haID := route.GetHAUniqueID(newRoute)
|
||||
if !ownNetworkIDs[haID] {
|
||||
if !isRouteSupported(newRoute) {
|
||||
if !isPrefixSupported(newRoute.Network) {
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
|
||||
@@ -306,23 +268,23 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
|
||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||
_, crMap := m.classifyRoutes(initialRoutes)
|
||||
rs := make([]*route.Route, 0, len(crMap))
|
||||
rs := make([]*route.Route, 0)
|
||||
for _, routes := range crMap {
|
||||
rs = append(rs, routes...)
|
||||
}
|
||||
return rs
|
||||
}
|
||||
|
||||
func isRouteSupported(route *route.Route) bool {
|
||||
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
return true
|
||||
}
|
||||
|
||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||
// we skip this prefix management
|
||||
if route.Network.Bits() <= vars.MinRangeBits {
|
||||
if prefix.Bits() <= minRangeBits {
|
||||
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
||||
version.NetbirdVersion(), route.Network)
|
||||
version.NetbirdVersion(), prefix)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -36,6 +36,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
serverRoutesExpected int
|
||||
clientNetworkWatchersExpected int
|
||||
clientNetworkWatchersExpectedAllowed int
|
||||
isV6 bool
|
||||
}{
|
||||
{
|
||||
name: "Should create 2 client networks",
|
||||
@@ -65,6 +66,35 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 2,
|
||||
},
|
||||
{
|
||||
name: "Should create 2 client networks (IPv6)",
|
||||
inputInitRoutes: []*route.Route{},
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 2,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "Should Create 2 Server Routes",
|
||||
inputRoutes: []*route.Route{
|
||||
@@ -93,6 +123,34 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
serverRoutesExpected: 2,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "Should Create 2 Server Routes (IPv6)",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
serverRoutesExpected: 2,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 Route For Client And Server",
|
||||
inputRoutes: []*route.Route{
|
||||
@@ -121,6 +179,84 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
serverRoutesExpected: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 Route For Client And Server (IPv6)",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
serverRoutesExpected: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 Route For Client And Server for each IP version",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.9.9/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
serverRoutesExpected: 2,
|
||||
clientNetworkWatchersExpected: 2,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router",
|
||||
inputRoutes: []*route.Route{
|
||||
@@ -150,6 +286,36 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
serverRoutesExpected: 0,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router (IPv6)",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
removeSrvRouter: true,
|
||||
serverRoutesExpected: 0,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 HA Route and 1 Standalone",
|
||||
inputRoutes: []*route.Route{
|
||||
@@ -187,6 +353,44 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 2,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 HA Route and 1 Standalone (IPv6)",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey2,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "c",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 2,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "No Small Client Route Should Be Added",
|
||||
inputRoutes: []*route.Route{
|
||||
@@ -205,6 +409,25 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
clientNetworkWatchersExpected: 0,
|
||||
clientNetworkWatchersExpectedAllowed: 1,
|
||||
},
|
||||
{
|
||||
name: "No Small Client Route Should Be Added (IPv6)",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("::/0"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
clientNetworkWatchersExpectedAllowed: 1,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "Remove 1 Client Route",
|
||||
inputInitRoutes: []*route.Route{
|
||||
@@ -244,6 +467,46 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "Remove 1 Client Route (IPv6)",
|
||||
inputInitRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "Update Route to HA",
|
||||
inputInitRoutes: []*route.Route{
|
||||
@@ -293,6 +556,56 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "Update Route to HA (IPv6)",
|
||||
inputInitRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey2,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "Remove Client Routes",
|
||||
inputInitRoutes: []*route.Route{
|
||||
@@ -321,6 +634,35 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "Remove Client Routes (IPv6)",
|
||||
inputInitRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputRoutes: []*route.Route{},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "Remove All Routes",
|
||||
inputInitRoutes: []*route.Route{
|
||||
@@ -350,6 +692,36 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
serverRoutesExpected: 0,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "Remove All Routes (IPv6)",
|
||||
inputInitRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputRoutes: []*route.Route{},
|
||||
inputSerial: 1,
|
||||
serverRoutesExpected: 0,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
isV6: true,
|
||||
},
|
||||
{
|
||||
name: "HA server should not register routes from the same HA group",
|
||||
inputRoutes: []*route.Route{
|
||||
@@ -398,16 +770,74 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
serverRoutesExpected: 2,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "HA server should not register routes from the same HA group (IPv6)",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "l1",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "l2",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "r1",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "r2",
|
||||
NetID: "routeC",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("2001:db8::abcd:789f/128"),
|
||||
NetworkType: route.IPv6Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
serverRoutesExpected: 2,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
isV6: true,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
v6Addr := ""
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if !iface.SupportsIPv6() && testCase.isV6 {
|
||||
t.Skip("Platform does not support IPv6, skipping IPv6 test...")
|
||||
} else if testCase.isV6 {
|
||||
v6Addr = "2001:db8::4242:4711/128"
|
||||
}
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", v6Addr, 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
@@ -416,7 +846,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
|
||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
@@ -436,7 +866,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||
if testCase.clientNetworkWatchersExpectedAllowed != 0 {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
|
||||
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
|
||||
}
|
||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// MockManager is the mock instance of a route manager
|
||||
@@ -20,7 +20,7 @@ type MockManager struct {
|
||||
StopFunc func()
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@@ -64,6 +64,10 @@ func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *MockManager) ResetV6Routes() {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// Stop mock implementation of Stop from Manager interface
|
||||
func (m *MockManager) Stop() {
|
||||
if m.StopFunc != nil {
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
package refcounter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
|
||||
var ErrIgnore = errors.New("ignore")
|
||||
|
||||
type Ref[O any] struct {
|
||||
Count int
|
||||
Out O
|
||||
}
|
||||
|
||||
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
|
||||
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
|
||||
|
||||
type Counter[I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for prefixes
|
||||
refCountMap map[netip.Prefix]Ref[O]
|
||||
refCountMu sync.Mutex
|
||||
// idMap keeps track of the prefixes associated with an ID for removal
|
||||
idMap map[string][]netip.Prefix
|
||||
idMu sync.Mutex
|
||||
add AddFunc[I, O]
|
||||
remove RemoveFunc[I, O]
|
||||
}
|
||||
|
||||
// New creates a new Counter instance
|
||||
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
|
||||
return &Counter[I, O]{
|
||||
refCountMap: map[netip.Prefix]Ref[O]{},
|
||||
idMap: map[string][]netip.Prefix{},
|
||||
add: add,
|
||||
remove: remove,
|
||||
}
|
||||
}
|
||||
|
||||
// Increment increments the reference count for the given prefix.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
|
||||
// Call AddFunc only if it's a new prefix
|
||||
if ref.Count == 0 {
|
||||
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
|
||||
out, err := rm.add(prefix, in)
|
||||
|
||||
if errors.Is(err, ErrIgnore) {
|
||||
return ref, nil
|
||||
}
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
|
||||
}
|
||||
ref.Out = out
|
||||
}
|
||||
|
||||
ref.Count++
|
||||
rm.refCountMap[prefix] = ref
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
ref, err := rm.Increment(prefix, in)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("with ID: %w", err)
|
||||
}
|
||||
rm.idMap[id] = append(rm.idMap[id], prefix)
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// Decrement decrements the reference count for the given prefix.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[prefix]
|
||||
if !ok {
|
||||
log.Tracef("No reference found for prefix %s", prefix)
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
if ref.Count == 1 {
|
||||
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
|
||||
}
|
||||
delete(rm.refCountMap, prefix)
|
||||
} else {
|
||||
ref.Count--
|
||||
rm.refCountMap[prefix] = ref
|
||||
}
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[I, O]) DecrementWithID(id string) error {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range rm.idMap[id] {
|
||||
if _, err := rm.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
delete(rm.idMap, id)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// Flush removes all references and calls RemoveFunc for each prefix.
|
||||
func (rm *Counter[I, O]) Flush() error {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for prefix := range rm.refCountMap {
|
||||
log.Tracef("Removing for prefix %s", prefix)
|
||||
ref := rm.refCountMap[prefix]
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
rm.refCountMap = map[netip.Prefix]Ref[O]{}
|
||||
|
||||
rm.idMap = map[string][]netip.Prefix{}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package refcounter
|
||||
|
||||
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
|
||||
type RouteRefCounter = Counter[any, any]
|
||||
|
||||
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
|
||||
type AllowedIPsRefCounter = Counter[string, string]
|
||||
127
client/internal/routemanager/routemanager.go
Normal file
127
client/internal/routemanager/routemanager.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type ref struct {
|
||||
count int
|
||||
nexthop netip.Addr
|
||||
intf *net.Interface
|
||||
}
|
||||
|
||||
type RouteManager struct {
|
||||
// refCountMap keeps track of the reference ref for prefixes
|
||||
refCountMap map[netip.Prefix]ref
|
||||
// prefixMap keeps track of the prefixes associated with a connection ID for removal
|
||||
prefixMap map[nbnet.ConnectionID][]netip.Prefix
|
||||
addRoute AddRouteFunc
|
||||
removeRoute RemoveRouteFunc
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
|
||||
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
|
||||
|
||||
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
|
||||
// TODO: read initial routing table into refCountMap
|
||||
return &RouteManager{
|
||||
refCountMap: map[netip.Prefix]ref{},
|
||||
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
|
||||
addRoute: addRoute,
|
||||
removeRoute: removeRoute,
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
|
||||
|
||||
// Add route to the system, only if it's a new prefix
|
||||
if ref.count == 0 {
|
||||
log.Debugf("Adding route for prefix %s", prefix)
|
||||
nexthop, intf, err := rm.addRoute(prefix)
|
||||
if errors.Is(err, ErrRouteNotFound) {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, ErrRouteNotAllowed) {
|
||||
log.Debugf("Adding route for prefix %s: %s", prefix, err)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
|
||||
}
|
||||
ref.nexthop = nexthop
|
||||
ref.intf = intf
|
||||
}
|
||||
|
||||
ref.count++
|
||||
rm.refCountMap[prefix] = ref
|
||||
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
prefixes, ok := rm.prefixMap[connID]
|
||||
if !ok {
|
||||
log.Debugf("No prefixes found for connection ID %s", connID)
|
||||
return nil
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
|
||||
if ref.count == 1 {
|
||||
log.Debugf("Removing route for prefix %s", prefix)
|
||||
// TODO: don't fail if the route is not found
|
||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||
continue
|
||||
}
|
||||
delete(rm.refCountMap, prefix)
|
||||
} else {
|
||||
ref.count--
|
||||
rm.refCountMap[prefix] = ref
|
||||
}
|
||||
}
|
||||
delete(rm.prefixMap, connID)
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// Flush removes all references and routes from the system
|
||||
func (rm *RouteManager) Flush() error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
var result *multierror.Error
|
||||
for prefix := range rm.refCountMap {
|
||||
log.Debugf("Removing route for prefix %s", prefix)
|
||||
ref := rm.refCountMap[prefix]
|
||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
rm.refCountMap = map[netip.Prefix]ref{}
|
||||
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -71,7 +70,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
|
||||
}
|
||||
|
||||
if len(m.routes) > 0 {
|
||||
err := systemops.EnableIPForwarding()
|
||||
err := enableIPForwarding(m.wgInterface.Address6() != nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -80,7 +79,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
func (m *defaultServerRouter) removeFromServerNetwork(rt *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("Not removing from server network because context is done")
|
||||
@@ -88,28 +87,32 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
routingAddress := m.wgInterface.Address().Masked().String()
|
||||
if rt.NetworkType == route.IPv6Network {
|
||||
if m.wgInterface.Address6() == nil {
|
||||
return fmt.Errorf("attempted to add route for IPv6 even though device has no v6 address")
|
||||
}
|
||||
routingAddress = m.wgInterface.Address6().Masked().String()
|
||||
}
|
||||
routerPair, err := routeToRouterPair(routingAddress, rt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove routing rules: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
delete(m.routes, route.ID)
|
||||
delete(m.routes, rt.ID)
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
delete(state.Routes, route.Network.String())
|
||||
delete(state.Routes, rt.Network.String())
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
func (m *defaultServerRouter) addToServerNetwork(rt *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("Not adding to server network because context is done")
|
||||
@@ -117,8 +120,15 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
routingAddress := m.wgInterface.Address().Masked().String()
|
||||
if rt.NetworkType == route.IPv6Network {
|
||||
if m.wgInterface.Address6() == nil {
|
||||
return fmt.Errorf("attempted to add route for IPv6 even though device has no v6 address")
|
||||
}
|
||||
routingAddress = m.wgInterface.Address6().Masked().String()
|
||||
}
|
||||
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
routerPair, err := routeToRouterPair(routingAddress, rt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
@@ -128,19 +138,13 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
return fmt.Errorf("insert routing rules: %w", err)
|
||||
}
|
||||
|
||||
m.routes[route.ID] = route
|
||||
m.routes[rt.ID] = rt
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
if state.Routes == nil {
|
||||
state.Routes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
routeStr := route.Network.String()
|
||||
if route.IsDynamic() {
|
||||
routeStr = route.Domains.SafeString()
|
||||
}
|
||||
state.Routes[routeStr] = struct{}{}
|
||||
|
||||
state.Routes[rt.Network.String()] = struct{}{}
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
@@ -151,10 +155,17 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
for _, r := range m.routes {
|
||||
routerPair, err := routeToRouterPair(r)
|
||||
routingAddress := m.wgInterface.Address().Masked().String()
|
||||
if r.NetworkType == route.IPv6Network {
|
||||
if m.wgInterface.Address6() == nil {
|
||||
log.Errorf("attempted to remove route for IPv6 even though device has no v6 address")
|
||||
continue
|
||||
}
|
||||
routingAddress = m.wgInterface.Address6().Masked().String()
|
||||
}
|
||||
routerPair, err := routeToRouterPair(routingAddress, r)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to convert route to router pair: %v", err)
|
||||
continue
|
||||
log.Errorf("parse prefix: %v", err)
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||
@@ -169,27 +180,15 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
}
|
||||
|
||||
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
||||
// TODO: add ipv6
|
||||
source := getDefaultPrefix(route.Network)
|
||||
|
||||
destination := route.Network.Masked().String()
|
||||
if route.IsDynamic() {
|
||||
// TODO: add ipv6
|
||||
destination = "0.0.0.0/0"
|
||||
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
|
||||
parsed, err := netip.ParsePrefix(source)
|
||||
if err != nil {
|
||||
return firewall.RouterPair{}, err
|
||||
}
|
||||
|
||||
return firewall.RouterPair{
|
||||
ID: string(route.ID),
|
||||
Source: source.String(),
|
||||
Destination: destination,
|
||||
Source: parsed.String(),
|
||||
Destination: route.Network.Masked().String(),
|
||||
Masquerade: route.Masquerade,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getDefaultPrefix(prefix netip.Prefix) netip.Prefix {
|
||||
if prefix.Addr().Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
}
|
||||
|
||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
}
|
||||
}
|
||||
|
||||
// Route route methods
|
||||
func (r *Route) String() string {
|
||||
return r.route.Network.String()
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(context.Context) error {
|
||||
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Route) RemoveRoute() error {
|
||||
_, err := r.routeRefCounter.Decrement(r.route.Network)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
|
||||
return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err)
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
r.route.Network,
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) RemoveAllowedIPs() error {
|
||||
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
|
||||
return err
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
// go:build !android
|
||||
package sysctl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
// Setup configures sysctl settings for RP filtering and source validation.
|
||||
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
oldVal, err := Set(srcValidMarkPath, 1, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[srcValidMarkPath] = oldVal
|
||||
}
|
||||
|
||||
oldVal, err = Set(rpFilterPath, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[rpFilterPath] = oldVal
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||
}
|
||||
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||
continue
|
||||
}
|
||||
|
||||
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||
oldVal, err := Set(i, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[i] = oldVal
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||
func Set(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
// Cleanup resets sysctl settings to their original values.
|
||||
func Cleanup(originalSettings map[string]int) error {
|
||||
var result *multierror.Error
|
||||
|
||||
for key, value := range originalSettings {
|
||||
_, err := Set(key, value, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
424
client/internal/routemanager/systemops.go
Normal file
424
client/internal/routemanager/systemops.go
Normal file
@@ -0,0 +1,424 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRouteNotFound = errors.New("route not found")
|
||||
var ErrRouteNotAllowed = errors.New("route not allowed")
|
||||
|
||||
// TODO: fix: for default our wg address now appears as the default gw
|
||||
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if prefix.Addr().Is6() {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
defaultGateway, _, err := GetNextHop(addr)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
if !prefix.Contains(defaultGateway) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
||||
if defaultGateway.Is6() {
|
||||
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
||||
}
|
||||
|
||||
ok, err := existsInRouteTable(gatewayPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayHop, intf, err := GetNextHop(defaultGateway)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
||||
}
|
||||
|
||||
func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to get route for %s: %v", ip, err)
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
if gateway == nil {
|
||||
if preferredSrc == nil {
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
}
|
||||
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
||||
|
||||
addr, err := ipToAddr(preferredSrc, intf)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
|
||||
}
|
||||
return addr.Unmap(), intf, nil
|
||||
}
|
||||
|
||||
addr, err := ipToAddr(gateway, intf)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
|
||||
}
|
||||
|
||||
return addr, intf, nil
|
||||
}
|
||||
|
||||
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
||||
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
|
||||
}
|
||||
|
||||
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
|
||||
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
|
||||
if runtime.GOOS == "windows" {
|
||||
addr = addr.WithZone(strconv.Itoa(intf.Index))
|
||||
} else {
|
||||
addr = addr.WithZone(intf.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
|
||||
linkLocalPrefix, err := netip.ParsePrefix("fe80::/10")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if prefix.Addr().Is6() && linkLocalPrefix.Contains(prefix.Addr()) {
|
||||
// The link local prefix is not explicitly part of the routing table, but should be considered as such.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute == prefix {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
|
||||
return netip.Addr{}, nil, ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
nexthop, intf, err := GetNextHop(addr)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
||||
exitNextHop := nexthop
|
||||
exitIntf := intf
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
||||
exitNextHop = initialNextHop
|
||||
exitIntf = initialIntf
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
||||
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
|
||||
}
|
||||
|
||||
return exitNextHop, exitIntf, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
// in two /1 prefixes to avoid replacing the existing default route
|
||||
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
} else if prefix == defaultv6 {
|
||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return addNonExistingRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
err := addRouteForCurrentDefaultGateway(prefix)
|
||||
if err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return addToRouteTable(prefix, netip.Addr{}, intf)
|
||||
}
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
// it will remove the split /1 prefixes
|
||||
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
var result *multierror.Error
|
||||
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
} else if prefix == defaultv6 {
|
||||
var result *multierror.Error
|
||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
||||
}
|
||||
|
||||
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parse IP address: %s", ip)
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
var prefixLength int
|
||||
switch {
|
||||
case addr.Is4():
|
||||
prefixLength = 32
|
||||
case addr.Is6():
|
||||
prefixLength = 128
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid IP address: %s", addr)
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||
return &prefix, nil
|
||||
}
|
||||
|
||||
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||
}
|
||||
initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||
}
|
||||
|
||||
*routeManager = NewRouteManager(
|
||||
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
nexthop, intf := initialNextHopV4, initialIntfV4
|
||||
if addr.Is6() {
|
||||
nexthop, intf = initialNextHopV6, initialIntfV6
|
||||
}
|
||||
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
||||
},
|
||||
removeFromRouteTable,
|
||||
)
|
||||
|
||||
return setupHooks(*routeManager, initAddresses)
|
||||
}
|
||||
|
||||
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
||||
if routeManager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
|
||||
if err := routeManager.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := getPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
log.Errorf("Failed to add route reference: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
return beforeHook, afterHook, nil
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build darwin || dragonfly || netbsd || openbsd
|
||||
|
||||
package systemops
|
||||
|
||||
import "syscall"
|
||||
|
||||
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:build: freebsd
|
||||
package systemops
|
||||
|
||||
import "syscall"
|
||||
|
||||
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
|
||||
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type Nexthop struct {
|
||||
IP netip.Addr
|
||||
Intf *net.Interface
|
||||
}
|
||||
|
||||
type ExclusionCounter = refcounter.Counter[any, Nexthop]
|
||||
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface *iface.WGIface
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
}
|
||||
@@ -1,473 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||
}
|
||||
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||
}
|
||||
|
||||
refCounter := refcounter.New(
|
||||
func(prefix netip.Prefix, _ any) (Nexthop, error) {
|
||||
initialNexthop := initialNextHopV4
|
||||
if prefix.Addr().Is6() {
|
||||
initialNexthop = initialNextHopV6
|
||||
}
|
||||
|
||||
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
|
||||
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Tracef("Adding for prefix %s: %v", prefix, err)
|
||||
// These errors are not critical but also we should not track and try to remove the routes either.
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
return nexthop, err
|
||||
},
|
||||
r.removeFromRouteTable,
|
||||
)
|
||||
|
||||
r.refCounter = refCounter
|
||||
|
||||
return r.setupHooks(initAddresses)
|
||||
}
|
||||
|
||||
func (r *SysOps) cleanupRefCounter() error {
|
||||
if r.refCounter == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
|
||||
if err := r.refCounter.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: fix: for default our wg address now appears as the default gw
|
||||
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if prefix.Addr().Is6() {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
nexthop, err := GetNextHop(addr)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
if !prefix.Contains(nexthop.IP) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
|
||||
if nexthop.IP.Is6() {
|
||||
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
|
||||
}
|
||||
|
||||
ok, err := existsInRouteTable(gatewayPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
nexthop, err = GetNextHop(nexthop.IP)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
|
||||
return r.addToRouteTable(gatewayPrefix, nexthop)
|
||||
}
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
nexthop, err := GetNextHop(addr)
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
|
||||
exitNextHop := Nexthop{
|
||||
IP: nexthop.IP,
|
||||
Intf: nexthop.Intf,
|
||||
}
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
|
||||
|
||||
exitNextHop = initialNextHop
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP)
|
||||
if err := r.addToRouteTable(prefix, exitNextHop); err != nil {
|
||||
return Nexthop{}, fmt.Errorf("add route to table: %w", err)
|
||||
}
|
||||
|
||||
return exitNextHop, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
// in two /1 prefixes to avoid replacing the existing default route
|
||||
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
nextHop := Nexthop{netip.Addr{}, intf}
|
||||
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
} else if prefix == vars.Defaultv6 {
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.addNonExistingRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
|
||||
}
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
// it will remove the split /1 prefixes
|
||||
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
nextHop := Nexthop{netip.Addr{}, intf}
|
||||
|
||||
if prefix == vars.Defaultv4 {
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
} else if prefix == vars.Defaultv6 {
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
log.Errorf("Failed to add route reference: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
return beforeHook, afterHook, nil
|
||||
}
|
||||
|
||||
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to get route for %s: %v", ip, err)
|
||||
return Nexthop{}, vars.ErrRouteNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
if gateway == nil {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
return Nexthop{Intf: intf}, nil
|
||||
}
|
||||
|
||||
if preferredSrc == nil {
|
||||
return Nexthop{}, vars.ErrRouteNotFound
|
||||
}
|
||||
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
|
||||
|
||||
addr, err := ipToAddr(preferredSrc, intf)
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
|
||||
}
|
||||
return Nexthop{
|
||||
IP: addr,
|
||||
Intf: intf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
addr, err := ipToAddr(gateway, intf)
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err)
|
||||
}
|
||||
|
||||
return Nexthop{
|
||||
IP: addr,
|
||||
Intf: intf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
||||
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
|
||||
}
|
||||
|
||||
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
|
||||
zone := intf.Name
|
||||
if runtime.GOOS == "windows" {
|
||||
zone = strconv.Itoa(intf.Index)
|
||||
}
|
||||
log.Tracef("Adding zone %s to address %s", zone, addr)
|
||||
addr = addr.WithZone(zone)
|
||||
}
|
||||
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute == prefix {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||
localRoutes, err := hasSeparateRouting()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrRoutingIsSeparate) {
|
||||
log.Errorf("Failed to get routes: %v", err)
|
||||
}
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
return isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
}
|
||||
|
||||
func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||
vpnPrefixMap := map[netip.Prefix]struct{}{}
|
||||
for _, prefix := range vpnRoutes {
|
||||
vpnPrefixMap[prefix] = struct{}{}
|
||||
}
|
||||
|
||||
// remove vpnRoute duplicates
|
||||
for _, prefix := range localRoutes {
|
||||
delete(vpnPrefixMap, prefix)
|
||||
}
|
||||
|
||||
var longestPrefix netip.Prefix
|
||||
var isVpn bool
|
||||
|
||||
combinedRoutes := make([]netip.Prefix, len(vpnRoutes)+len(localRoutes))
|
||||
copy(combinedRoutes, vpnRoutes)
|
||||
copy(combinedRoutes[len(vpnRoutes):], localRoutes)
|
||||
|
||||
for _, prefix := range combinedRoutes {
|
||||
// Ignore the default route, it has special handling
|
||||
if prefix.Bits() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefix.Contains(addr) {
|
||||
// Longest prefix match
|
||||
if !longestPrefix.IsValid() || prefix.Bits() > longestPrefix.Bits() {
|
||||
longestPrefix = prefix
|
||||
_, isVpn = vpnPrefixMap[prefix]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !longestPrefix.IsValid() {
|
||||
// No route matched
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
// Return true if the longest matching prefix is from vpnRoutes
|
||||
return isVpn, longestPrefix
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
//go:build ios || android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
//go:build !linux && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return getRoutesFromTable()
|
||||
}
|
||||
33
client/internal/routemanager/systemops_android.go
Normal file
33
client/internal/routemanager/systemops_android.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package systemops
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -43,7 +43,8 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
|
||||
}
|
||||
|
||||
if filterRoutesByFlags(m.Flags) {
|
||||
if m.Flags&syscall.RTF_UP == 0 ||
|
||||
m.Flags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -92,7 +93,7 @@ func toNetIP(a route.Addr) netip.Addr {
|
||||
case *route.Inet6Addr:
|
||||
ip := netip.AddrFrom16(t.IP)
|
||||
if t.ZoneID != 0 {
|
||||
ip = ip.WithZone(strconv.Itoa(t.ZoneID))
|
||||
ip.WithZone(strconv.Itoa(t.ZoneID))
|
||||
}
|
||||
return ip
|
||||
default:
|
||||
@@ -100,7 +101,6 @@ func toNetIP(a route.Addr) netip.Addr {
|
||||
}
|
||||
}
|
||||
|
||||
// ones returns the number of leading ones in the mask.
|
||||
func ones(a route.Addr) (int, error) {
|
||||
switch t := a.(type) {
|
||||
case *route.Inet4Addr:
|
||||
@@ -114,7 +114,6 @@ func ones(a route.Addr) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// MsgToRoute converts a route message to a Route.
|
||||
func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
|
||||
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]
|
||||
|
||||
57
client/internal/routemanager/systemops_bsd_test.go
Normal file
57
client/internal/routemanager/systemops_bsd_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr route.Addr
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 all ones",
|
||||
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
|
||||
want: 32,
|
||||
},
|
||||
{
|
||||
name: "IPv4 normal mask",
|
||||
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
|
||||
want: 24,
|
||||
},
|
||||
{
|
||||
name: "IPv6 all ones",
|
||||
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
|
||||
want: 128,
|
||||
},
|
||||
{
|
||||
name: "IPv6 normal mask",
|
||||
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||
want: 64,
|
||||
},
|
||||
{
|
||||
name: "Unsupported type",
|
||||
addr: &route.LinkAddr{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ones(tt.addr)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||
//go:build darwin && !ios
|
||||
|
||||
package systemops
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -13,41 +13,43 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses)
|
||||
var routeManager *RouteManager
|
||||
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return r.cleanupRefCounter()
|
||||
func cleanupRouting() error {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("add", prefix, nexthop)
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("add", prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("delete", prefix, nexthop)
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("delete", prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
|
||||
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
inet := "-inet"
|
||||
if prefix.Addr().Is6() {
|
||||
inet = "-inet6"
|
||||
}
|
||||
|
||||
network := prefix.String()
|
||||
if prefix.IsSingleIP() {
|
||||
network = prefix.Addr().String()
|
||||
}
|
||||
if prefix.Addr().Is6() {
|
||||
inet = "-inet6"
|
||||
}
|
||||
|
||||
args := []string{"-n", action, inet, network}
|
||||
if nexthop.IP.IsValid() {
|
||||
args = append(args, nexthop.IP.Unmap().String())
|
||||
} else if nexthop.Intf != nil {
|
||||
args = append(args, "-interface", nexthop.Intf.Name)
|
||||
if nexthop.IsValid() {
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
} else if intf != nil {
|
||||
args = append(args, "-interface", intf.Name)
|
||||
}
|
||||
|
||||
if err := retryRouteCmd(args); err != nil {
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
//go:build !ios
|
||||
|
||||
package systemops
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
var expectedVPNint = "utun100"
|
||||
@@ -36,15 +35,13 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
intf := &net.Interface{Name: "lo0"}
|
||||
|
||||
r := NewSysOps(nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
@@ -60,7 +57,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
@@ -70,53 +67,6 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr route.Addr
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 all ones",
|
||||
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
|
||||
want: 32,
|
||||
},
|
||||
{
|
||||
name: "IPv4 normal mask",
|
||||
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
|
||||
want: 24,
|
||||
},
|
||||
{
|
||||
name: "IPv6 all ones",
|
||||
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
|
||||
want: 128,
|
||||
},
|
||||
{
|
||||
name: "IPv6 normal mask",
|
||||
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||
want: 64,
|
||||
},
|
||||
{
|
||||
name: "Unsupported type",
|
||||
addr: &route.LinkAddr{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ones(tt.addr)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
33
client/internal/routemanager/systemops_ios.go
Normal file
33
client/internal/routemanager/systemops_ios.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding(includeV6 bool) error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !android
|
||||
|
||||
package systemops
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -9,15 +9,16 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -32,10 +33,16 @@ const (
|
||||
|
||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||
|
||||
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||
|
||||
var routeManager = &RouteManager{}
|
||||
|
||||
// originalSysctl stores the original sysctl values before they are modified
|
||||
var originalSysctl map[string]int
|
||||
|
||||
@@ -75,7 +82,7 @@ func getSetupRules() []ruleParams {
|
||||
}
|
||||
}
|
||||
|
||||
// SetupRouting establishes the routing configuration for the VPN, including essential rules
|
||||
// setupRouting establishes the routing configuration for the VPN, including essential rules
|
||||
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
|
||||
//
|
||||
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
|
||||
@@ -85,17 +92,17 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||
if isLegacy() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return r.setupRefCounter(initAddresses)
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||
originalValues, err := setupSysctl(wgIface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
@@ -104,7 +111,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := r.CleanupRouting(); cleanErr != nil {
|
||||
if cleanErr := cleanupRouting(); cleanErr != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||
}
|
||||
}
|
||||
@@ -116,7 +123,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
|
||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||
setIsLegacy(true)
|
||||
return r.setupRefCounter(initAddresses)
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||
}
|
||||
@@ -125,12 +132,12 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
func cleanupRouting() error {
|
||||
if isLegacy() {
|
||||
return r.cleanupRefCounter()
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
@@ -149,58 +156,46 @@ func (r *SysOps) CleanupRouting() error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := sysctl.Cleanup(originalSysctl); err != nil {
|
||||
if err := cleanupSysctl(originalSysctl); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
|
||||
}
|
||||
originalSysctl = nil
|
||||
sysctlFailed = false
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return addRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return removeRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
if sysctlFailed && (prefix == vars.Defaultv4 || prefix == vars.Defaultv6) {
|
||||
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
|
||||
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
|
||||
}
|
||||
|
||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("add blackhole: %w", err)
|
||||
}
|
||||
}
|
||||
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("remove unreachable route: %w", err)
|
||||
}
|
||||
}
|
||||
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -248,7 +243,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||
}
|
||||
|
||||
// addRoute adds a route to a specific routing table identified by tableID.
|
||||
func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Table: tableID,
|
||||
@@ -261,7 +256,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
}
|
||||
route.Dst = ipNet
|
||||
|
||||
if err := addNextHop(nexthop, route); err != nil {
|
||||
if err := addNextHop(addr, intf, route); err != nil {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
@@ -275,6 +270,9 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
|
||||
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
|
||||
// tableID specifies the routing table to which the unreachable route will be added.
|
||||
// TODO should this be kept in for future use? If so, the linter needs to be told that this unreachable function should
|
||||
//
|
||||
// be kept
|
||||
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
@@ -295,6 +293,9 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO should this be kept in for future use? If so, the linter needs to be told that this unreachable function should
|
||||
//
|
||||
// be kept
|
||||
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
@@ -320,7 +321,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
}
|
||||
|
||||
// removeRoute removes a route from a specific routing table identified by tableID.
|
||||
func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
@@ -333,7 +334,7 @@ func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := addNextHop(nexthop, route); err != nil {
|
||||
if err := addNextHop(addr, intf, route); err != nil {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
@@ -366,11 +367,17 @@ func flushRoutes(tableID, family int) error {
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
_, err := sysctl.Set(ipv4ForwardingPath, 1, false)
|
||||
func enableIPForwarding(includeV6 bool) error {
|
||||
_, err := setSysctl(ipv4ForwardingPath, 1, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if includeV6 {
|
||||
_, err = setSysctl(ipv4ForwardingPath, 1, false)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -474,19 +481,19 @@ func removeRule(params ruleParams) error {
|
||||
}
|
||||
|
||||
// addNextHop adds the gateway and device to the route.
|
||||
func addNextHop(nexthop Nexthop, route *netlink.Route) error {
|
||||
if nexthop.Intf != nil {
|
||||
route.LinkIndex = nexthop.Intf.Index
|
||||
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
|
||||
if intf != nil {
|
||||
route.LinkIndex = intf.Index
|
||||
}
|
||||
|
||||
if nexthop.IP.IsValid() {
|
||||
route.Gw = nexthop.IP.AsSlice()
|
||||
if addr.IsValid() {
|
||||
route.Gw = addr.AsSlice()
|
||||
|
||||
// if zone is set, it means the gateway is a link-local address, so we set the link index
|
||||
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
|
||||
link, err := netlink.LinkByName(nexthop.IP.Zone())
|
||||
if addr.Zone() != "" && intf == nil {
|
||||
link, err := netlink.LinkByName(addr.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("get link by name for zone %s: %w", nexthop.IP.Zone(), err)
|
||||
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
|
||||
}
|
||||
route.LinkIndex = link.Attrs().Index
|
||||
}
|
||||
@@ -502,9 +509,82 @@ func getAddressFamily(prefix netip.Prefix) int {
|
||||
return netlink.FAMILY_V6
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
if isLegacy() {
|
||||
return getRoutesFromTable()
|
||||
// setupSysctl configures sysctl settings for RP filtering and source validation.
|
||||
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[srcValidMarkPath] = oldVal
|
||||
}
|
||||
return nil, ErrRoutingIsSeparate
|
||||
|
||||
oldVal, err = setSysctl(rpFilterPath, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[rpFilterPath] = oldVal
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||
}
|
||||
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||
continue
|
||||
}
|
||||
|
||||
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||
oldVal, err := setSysctl(i, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[i] = oldVal
|
||||
}
|
||||
}
|
||||
|
||||
return keys, result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
func cleanupSysctl(originalSettings map[string]int) error {
|
||||
var result *multierror.Error
|
||||
|
||||
for key, value := range originalSettings {
|
||||
_, err := setSysctl(key, value, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !android
|
||||
|
||||
package systemops
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
@@ -140,7 +138,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
|
||||
if dstIPNet.String() == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
@@ -195,7 +193,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, vars.ErrRouteNotFound
|
||||
return nil, 0, ErrRouteNotFound
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
24
client/internal/routemanager/systemops_nonlinux.go
Normal file
24
client/internal/routemanager/systemops_nonlinux.go
Normal file
@@ -0,0 +1,24 @@
|
||||
//go:build !linux && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func enableIPForwarding(includeV6 bool) error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/google/gopacket/routing"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -46,41 +48,55 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
shouldRouteToWireguard: false,
|
||||
shouldBeRemoved: false,
|
||||
},
|
||||
{
|
||||
name: "Should Add And Remove Route 2001:db8:1234:5678::/64",
|
||||
prefix: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
|
||||
shouldRouteToWireguard: true,
|
||||
shouldBeRemoved: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Or Remove Route ::1/128",
|
||||
prefix: netip.MustParsePrefix("::1/128"),
|
||||
shouldRouteToWireguard: false,
|
||||
shouldBeRemoved: false,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping ", testCase.name, " on freebsd")
|
||||
}
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
v6Addr := ""
|
||||
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if (!iface.SupportsIPv6() || !firewall.SupportsIPv6() || !hasV6DefaultRoute || err != nil) && testCase.prefix.Addr().Is6() {
|
||||
t.Skip("Platform does not support IPv6, skipping IPv6 test...")
|
||||
} else if testCase.prefix.Addr().Is6() {
|
||||
v6Addr = "2001:db8::4242:4711/128"
|
||||
}
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", v6Addr, 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
|
||||
_, _, err = r.SetupRouting(nil)
|
||||
_, _, err = setupRouting(nil, wgInterface)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting())
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||
|
||||
if testCase.shouldRouteToWireguard {
|
||||
@@ -91,19 +107,23 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
exists, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||
if exists && testCase.shouldRouteToWireguard {
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||
|
||||
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
|
||||
prefixGateway, _, err := GetNextHop(testCase.prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
|
||||
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
require.NoError(t, err)
|
||||
if testCase.prefix.Addr().Is6() {
|
||||
internetGateway, _, err = GetNextHop(netip.MustParseAddr("::/0"))
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
if testCase.shouldBeRemoved {
|
||||
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
|
||||
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
|
||||
} else {
|
||||
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
|
||||
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -111,14 +131,11 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetNextHop(t *testing.T) {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping on freebsd")
|
||||
}
|
||||
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
if !nexthop.IP.IsValid() {
|
||||
if !gateway.IsValid() {
|
||||
t.Fatal("should return a gateway")
|
||||
}
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
@@ -140,27 +157,39 @@ func TestGetNextHop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
localIP, err := GetNextHop(testingPrefix.Addr())
|
||||
localIP, _, err := GetNextHop(testingPrefix.Addr())
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error: ", err)
|
||||
}
|
||||
if !localIP.IP.IsValid() {
|
||||
if !localIP.IsValid() {
|
||||
t.Fatal("should return a gateway for local network")
|
||||
}
|
||||
if localIP.IP.String() == nexthop.IP.String() {
|
||||
t.Fatal("local IP should not match with gateway IP")
|
||||
if localIP.String() == gateway.String() {
|
||||
t.Fatal("local ip should not match with gateway IP")
|
||||
}
|
||||
if localIP.IP.String() != testingIP {
|
||||
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
|
||||
if localIP.String() != testingIP {
|
||||
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
t.Log("defaultNexthop: ", defaultNexthop)
|
||||
defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
t.Log("defaultGateway: ", defaultGateway)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
var defaultGateway6 *netip.Addr
|
||||
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if iface.SupportsIPv6() && firewall.SupportsIPv6() && hasV6DefaultRoute && err == nil {
|
||||
gw6, _, err := GetNextHop(netip.MustParseAddr("::"))
|
||||
gw6 = gw6.WithZone("")
|
||||
defaultGateway6 = &gw6
|
||||
t.Log("defaultGateway6: ", defaultGateway6)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the IPv6 gateway: ", err)
|
||||
}
|
||||
}
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
@@ -174,7 +203,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if overlaps with default gateway",
|
||||
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
|
||||
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
{
|
||||
@@ -195,6 +224,43 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
{
|
||||
name: "Should Add And Remove random Route (IPv6)",
|
||||
prefix: netip.MustParsePrefix("2001:db8::abcd/128"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if bigger network exists (IPv6)",
|
||||
prefix: netip.MustParsePrefix("2001:db8:b14d:abcd:1234::/96"),
|
||||
preExistingPrefix: netip.MustParsePrefix("2001:db8:b14d:abcd::/64"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if smaller network exists (IPv6)",
|
||||
prefix: netip.MustParsePrefix("2001:db8:b14d::/48"),
|
||||
preExistingPrefix: netip.MustParsePrefix("2001:db8:b14d:abcd::/64"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if same network exists (IPv6)",
|
||||
prefix: netip.MustParsePrefix("2001:db8:b14d:abcd::/64"),
|
||||
preExistingPrefix: netip.MustParsePrefix("2001:db8:b14d:abcd::/64"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
}
|
||||
if defaultGateway6 != nil {
|
||||
testCases = append(testCases, []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
preExistingPrefix netip.Prefix
|
||||
shouldAddRoute bool
|
||||
}{
|
||||
{
|
||||
name: "Should Not Add Route if overlaps with default gateway (IPv6)",
|
||||
prefix: netip.MustParsePrefix(defaultGateway6.String() + "/127"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
@@ -208,12 +274,19 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
v6Addr := ""
|
||||
if testCase.prefix.Addr().Is6() && defaultGateway6 == nil {
|
||||
t.Skip("Platform does not support IPv6, skipping IPv6 test...")
|
||||
} else if testCase.prefix.Addr().Is6() {
|
||||
v6Addr = "2001:db8::4242:4711/128"
|
||||
}
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", v6Addr, 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
@@ -224,16 +297,14 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
|
||||
err := addVPNRoute(testCase.preExistingPrefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding route")
|
||||
|
||||
if testCase.shouldAddRoute {
|
||||
@@ -243,7 +314,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
require.True(t, ok, "route should exist")
|
||||
|
||||
// remove route again if added
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err")
|
||||
}
|
||||
|
||||
@@ -261,6 +332,11 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIsSubRange(t *testing.T) {
|
||||
// Note: This test may fail for IPv6 in some environments, where there actually exists another route that the
|
||||
// determined prefix is a sub-range of.
|
||||
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
|
||||
shouldIncludeV6Routes := iface.SupportsIPv6() && firewall.SupportsIPv6() && hasV6DefaultRoute && err == nil
|
||||
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
@@ -270,7 +346,7 @@ func TestIsSubRange(t *testing.T) {
|
||||
var nonSubRangeAddressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
||||
if !p.Addr().IsLoopback() && (p.Addr().Is4() && p.Bits() < 32) || (p.Addr().Is6() && shouldIncludeV6Routes && p.Bits() < 128) {
|
||||
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
||||
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
||||
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
||||
@@ -298,31 +374,49 @@ func TestIsSubRange(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func EnvironmentHasIPv6DefaultRoute() (bool, error) {
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if runtime.GOOS != "linux" {
|
||||
// TODO when implementing IPv6 for other operating systems, this should be replaced with code that determines
|
||||
// whether a default route for IPv6 exists (routing.Router panics on non-linux).
|
||||
return false, nil
|
||||
}
|
||||
router, err := routing.New()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
routeIface, _, _, err := router.Route(netip.MustParsePrefix("::/0").Addr().AsSlice())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return routeIface != nil, nil
|
||||
}
|
||||
|
||||
func TestExistsInRouteTable(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
|
||||
shouldIncludeV6Routes := iface.SupportsIPv6() && firewall.SupportsIPv6() && hasV6DefaultRoute && err == nil
|
||||
|
||||
var addressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
|
||||
switch {
|
||||
case p.Addr().Is6():
|
||||
if p.Addr().Is6() && !shouldIncludeV6Routes {
|
||||
continue
|
||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
|
||||
continue
|
||||
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
|
||||
continue
|
||||
// FreeBSD loopback 127/8 is not added to the routing table
|
||||
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
|
||||
continue
|
||||
default:
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
}
|
||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
}
|
||||
|
||||
for _, prefix := range addressPrefixes {
|
||||
@@ -336,7 +430,7 @@ func TestExistsInRouteTable(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
||||
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, ipAddress6CIDR string, listenPort int) *iface.WGIface {
|
||||
t.Helper()
|
||||
|
||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
@@ -345,7 +439,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
newNet, err := stdnet.NewNet()
|
||||
require.NoError(t, err)
|
||||
|
||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, ipAddress6CIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WireGuard interface")
|
||||
|
||||
err = wgInterface.Create()
|
||||
@@ -358,52 +452,74 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
return wgInterface
|
||||
}
|
||||
|
||||
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
err := r.AddVPNRoute(prefix, intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = r.RemoveVPNRoute(prefix, intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
}
|
||||
|
||||
func setupTestEnv(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
setupDummyInterfacesAndRoutes(t)
|
||||
|
||||
wgInterface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
||||
v6Addr := ""
|
||||
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if !iface.SupportsIPv6() || !firewall.SupportsIPv6() || !hasV6DefaultRoute || err != nil {
|
||||
t.Skip("Platform does not support IPv6, skipping IPv6 test...")
|
||||
} else {
|
||||
v6Addr = "2001:db8::4242:4711/128"
|
||||
}
|
||||
|
||||
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", v6Addr, 51820)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, wgInterface.Close())
|
||||
assert.NoError(t, wgIface.Close())
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
_, _, err := r.SetupRouting(nil)
|
||||
_, _, err = setupRouting(nil, wgIface)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting())
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
index, err := net.InterfaceByName(wgIface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
|
||||
|
||||
// default route exists in main table and vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// 10.0.0.0/8 route exists in main table and vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// 10.10.0.0/24 more specific route exists in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// 127.0.10.0/24 more specific route exists in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// unique route in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
}
|
||||
|
||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||
@@ -412,133 +528,17 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
|
||||
return
|
||||
}
|
||||
|
||||
prefixNexthop, err := GetNextHop(prefix.Addr())
|
||||
prefixGateway, _, err := GetNextHop(prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
|
||||
nexthop := wgIface.Address().IP.String()
|
||||
if prefix.Addr().Is6() {
|
||||
nexthop = wgIface.Address6().IP.String()
|
||||
}
|
||||
|
||||
if invert {
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
|
||||
assert.NotEqual(t, nexthop, prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||
} else {
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
assert.Equal(t, nexthop, prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,10 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package systemops
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -86,10 +85,6 @@ var testCases = []testCase{
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping ", tc.name, " on freebsd")
|
||||
}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setupTestEnv(t)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build windows
|
||||
|
||||
package systemops
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -17,7 +17,8 @@ import (
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type MSFT_NetRoute struct {
|
||||
@@ -56,42 +57,14 @@ var prefixList []netip.Prefix
|
||||
var lastUpdate time.Time
|
||||
var mux = sync.Mutex{}
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses)
|
||||
var routeManager *RouteManager
|
||||
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return r.cleanupRefCounter()
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
|
||||
zone, err := strconv.Atoi(nexthop.IP.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid zone: %w", err)
|
||||
}
|
||||
nexthop.Intf = &net.Interface{Index: zone}
|
||||
}
|
||||
|
||||
return addRouteCmd(prefix, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if nexthop.IP.IsValid() {
|
||||
ip := nexthop.IP.WithZone("")
|
||||
args = append(args, ip.Unmap().String())
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
|
||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
func cleanupRouting() error {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
@@ -120,7 +93,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func GetRoutes() ([]Route, error) {
|
||||
var entries []MSFT_NetRoute
|
||||
|
||||
query := `SELECT DestinationPrefix, Nexthop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
|
||||
query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
|
||||
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
@@ -145,10 +118,6 @@ func GetRoutes() ([]Route, error) {
|
||||
Index: int(entry.InterfaceIndex),
|
||||
Name: entry.InterfaceAlias,
|
||||
}
|
||||
|
||||
if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) {
|
||||
nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex)))
|
||||
}
|
||||
}
|
||||
|
||||
routes = append(routes, Route{
|
||||
@@ -188,12 +157,11 @@ func GetNeighbors() ([]Neighbor, error) {
|
||||
return neighbors, nil
|
||||
}
|
||||
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
args := []string{"add", prefix.String()}
|
||||
|
||||
if nexthop.IP.IsValid() {
|
||||
ip := nexthop.IP.WithZone("")
|
||||
args = append(args, ip.Unmap().String())
|
||||
if nexthop.IsValid() {
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
} else {
|
||||
addr := "0.0.0.0"
|
||||
if prefix.Addr().Is6() {
|
||||
@@ -202,8 +170,8 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
args = append(args, addr)
|
||||
}
|
||||
|
||||
if nexthop.Intf != nil {
|
||||
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
|
||||
if intf != nil {
|
||||
args = append(args, "if", strconv.Itoa(intf.Index))
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
@@ -217,6 +185,37 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
if nexthop.Zone() != "" && intf == nil {
|
||||
zone, err := strconv.Atoi(nexthop.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid zone: %w", err)
|
||||
}
|
||||
intf = &net.Interface{Index: zone}
|
||||
nexthop.WithZone("")
|
||||
}
|
||||
|
||||
return addRouteCmd(prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if nexthop.IsValid() {
|
||||
nexthop.WithZone("")
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
|
||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCacheDisabled() bool {
|
||||
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package systemops
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -29,7 +29,7 @@ type FindNetRouteOutput struct {
|
||||
InterfaceIndex int `json:"InterfaceIndex"`
|
||||
InterfaceAlias string `json:"InterfaceAlias"`
|
||||
AddressFamily int `json:"AddressFamily"`
|
||||
NextHop string `json:"Nexthop"`
|
||||
NextHop string `json:"NextHop"`
|
||||
DestinationPrefix string `json:"DestinationPrefix"`
|
||||
}
|
||||
|
||||
@@ -166,7 +166,7 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
|
||||
host, _, err := net.SplitHostPort(destination)
|
||||
require.NoError(t, err)
|
||||
|
||||
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, Nexthop, DestinationPrefix | ConvertTo-Json`, host)
|
||||
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
|
||||
|
||||
out, err := exec.Command("powershell", "-Command", script).Output()
|
||||
require.NoError(t, err, "Failed to execute Find-NetRoute")
|
||||
@@ -207,7 +207,7 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (*RouteInfo, error) {
|
||||
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
|
||||
@@ -1,29 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// GetPrefixFromIP returns a netip.Prefix from a net.IP address.
|
||||
func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
var prefixLength int
|
||||
switch {
|
||||
case addr.Is4():
|
||||
prefixLength = 32
|
||||
case addr.Is6():
|
||||
prefixLength = 128
|
||||
default:
|
||||
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||
return prefix, nil
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package vars
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const MinRangeBits = 7
|
||||
|
||||
var (
|
||||
ErrRouteNotFound = errors.New("route not found")
|
||||
ErrRouteNotAllowed = errors.New("route not allowed")
|
||||
|
||||
Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
)
|
||||
@@ -3,11 +3,11 @@ package routeselector
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
route "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -30,10 +30,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
|
||||
var err *multierror.Error
|
||||
var multiErr *multierror.Error
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -41,7 +41,11 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
}
|
||||
rs.selectAll = false
|
||||
|
||||
return errors.FormatErrorOrNil(err)
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// SelectAllRoutes sets the selector to select all routes.
|
||||
@@ -61,17 +65,21 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
|
||||
}
|
||||
}
|
||||
|
||||
var err *multierror.Error
|
||||
var multiErr *multierror.Error
|
||||
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
delete(rs.selectedRoutes, route)
|
||||
}
|
||||
|
||||
return errors.FormatErrorOrNil(err)
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||
@@ -103,3 +111,18 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 1 {
|
||||
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
for i, err := range es {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%d errors occurred:\n\t%s",
|
||||
len(es), strings.Join(points, "\n\t"))
|
||||
}
|
||||
|
||||
@@ -261,15 +261,15 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
routes := route.HAMap{
|
||||
"route1|10.0.0.0/8": {},
|
||||
"route2|192.168.0.0/16": {},
|
||||
"route3|172.16.0.0/12": {},
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
"route3-172.16.0.0/12": {},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelected(routes)
|
||||
|
||||
assert.Equal(t, route.HAMap{
|
||||
"route1|10.0.0.0/8": {},
|
||||
"route2|192.168.0.0/16": {},
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
}, filtered)
|
||||
}
|
||||
|
||||
@@ -8,13 +8,9 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
|
||||
func NewFactory(ctx context.Context, wgPort int) *Factory {
|
||||
f := &Factory{wgPort: wgPort}
|
||||
|
||||
if userspace {
|
||||
return f
|
||||
}
|
||||
|
||||
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
|
||||
err := ebpfProxy.listen()
|
||||
if err != nil {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user