Compare commits

..

1 Commits

Author SHA1 Message Date
Hugo Hakim Damer
8b0398c0db Add support for IPv6 networks (on Linux clients) (#1459)
* Feat add basic support for IPv6 networks

Newly generated networks automatically generate an IPv6 prefix of size
64 within the ULA address range, devices obtain a randomly generated
address within this prefix.

Currently, this is Linux only and does not yet support all features
(routes currently cause an error).

* Fix firewall configuration for IPv6 networks

* Fix routing configuration for IPv6 networks

* Feat provide info on IPv6 support for specific client to mgmt server

* Feat allow configuration of IPv6 support through API, improve stability

* Feat add IPv6 support to new firewall implementation

* Fix peer list item response not containing IPv6 address

* Fix nftables breaking on IPv6 address change

* Fix build issues for non-linux systems

* Fix intermittent disconnections when IPv6 is enabled

* Fix test issues and make some minor revisions

* Fix some more testing issues

* Fix more CI issues due to IPv6

* Fix more testing issues

* Add inheritance of IPv6 enablement status from groups

* Fix IPv6 events not having associated messages

* Address first review comments regarding IPv6 support

* Fix IPv6 table being created even when IPv6 is disabled

Also improved stability of IPv6 route and firewall handling on client side

* Fix IPv6 routes not being removed

* Fix DNS IPv6 issues, limit IPv6 nameservers to IPv6 peers

* Improve code for IPv6 DNS server selection, add AAAA custom records

* Ensure IPv6 routes can only exist for IPv6 routing peers

* Fix IPv6 network generation randomness

* Fix a bunch of compilation issues and test failures

* Replace method calls that are unavailable in Go 1.21

* Fix nil dereference in cleanUpDefaultForwardRules6

* Fix nil pointer dereference when persisting IPv6 network in sqlite

* Clean up of client-side code changes for IPv6

* Fix nil dereference in rule mangling and compilation issues

* Add a bunch of client-side test cases for IPv6

* Fix IPv6 tests running on unsupported environments

* Fix import cycle in tests

* Add missing method SupportsIPv6() for windows

* Require IPv6 default route for IPv6 tests

* Fix panics in routemanager tests on non-linux

* Fix some more route manager tests concerning IPv6

* Add some final client-side tests

* Add IPv6 tests for management code, small fixes

* Fix linting issues

* Fix small test suite issues

* Fix linter issues and builds on macOS and Windows again

* fix builds for iOS because of IPv6 breakage
2024-08-13 17:26:27 +02:00
310 changed files with 10300 additions and 12346 deletions

View File

@@ -14,7 +14,7 @@ jobs:
test:
strategy:
matrix:
store: ['sqlite']
store: ['jsonfile', 'sqlite']
runs-on: macos-latest
steps:
- name: Install Go

View File

@@ -1,39 +0,0 @@
name: Test Code FreeBSD
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
with:
usesh: true
prepare: |
pkg install -y curl
pkg install -y git
run: |
set -x
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
tar zxf go.tar.gz
mv go /usr/local/go
ln -s /usr/local/go/bin/go /usr/local/bin/go
go mod tidy
go test -timeout 5m -p 1 ./iface/...
go test -timeout 5m -p 1 ./client/...
cd client
go build .
cd ..

View File

@@ -15,7 +15,7 @@ jobs:
strategy:
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
store: [ 'jsonfile', 'sqlite', 'postgres']
runs-on: ubuntu-latest
steps:
- name: Install Go
@@ -86,10 +86,7 @@ jobs:
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
- name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
@@ -111,9 +108,6 @@ jobs:
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -173,7 +173,7 @@ jobs:
retention-days: 3
release_ui_darwin:
runs-on: macos-latest
runs-on: macos-11
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV

View File

@@ -178,79 +178,34 @@ jobs:
- name: Checkout code
uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL
- name: run script
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen postgres
- name: test Caddy file gen
run: test -f Caddyfile
- name: test docker-compose file gen postgres
- name: test docker-compose file gen
run: test -f docker-compose.yml
- name: test management.json file gen postgres
- name: test management.json file gen
run: test -f management.json
- name: test turnserver.conf file gen postgres
- name: test turnserver.conf file gen
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen postgres
- name: test zitadel.env file gen
run: test -f zitadel.env
- name: test dashboard.env file gen postgres
- name: test dashboard.env file gen
run: test -f dashboard.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker-compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
test-download-geolite2-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code
uses: actions/checkout@v3
- name: test script
run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists
run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
run: test -f geonames.db

View File

@@ -3,10 +3,8 @@ builds:
- id: netbird-ui-darwin
dir: client/ui
binary: netbird-ui
env:
- CGO_ENABLED=1
- MACOSX_DEPLOYMENT_TARGET=11.0
- MACOS_DEPLOYMENT_TARGET=11.0
env: [CGO_ENABLED=1]
goos:
- darwin
goarch:

View File

@@ -1,4 +1,4 @@
FROM alpine:3.19
FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -59,7 +59,7 @@ var forCmd = &cobra.Command{
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
@@ -80,7 +80,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
}
func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
@@ -109,7 +109,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid duration format: %v", err)
}
conn, err := getClient(cmd)
conn, err := getClient(cmd.Context())
if err != nil {
return err
}

View File

@@ -36,7 +36,6 @@ const (
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
)
var (
@@ -69,9 +68,7 @@ var (
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{
rootCmd = &cobra.Command{
Use: "netbird",
Short: "",
Long: "",
@@ -356,11 +353,8 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true
}
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
func getClient(ctx context.Context) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+

View File

@@ -2,7 +2,6 @@ package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
@@ -50,7 +49,7 @@ func init() {
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
@@ -67,62 +66,20 @@ func routesList(cmd *cobra.Command, _ []string) error {
return nil
}
printRoutes(cmd, resp)
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
selectedStatus := "Not Selected"
if route.GetSelected() {
selectedStatus = "Selected"
}
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
return nil
}
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
@@ -149,7 +106,7 @@ func routesSelect(cmd *cobra.Command, args []string) error {
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
conn, err := getClient(cmd.Context())
if err != nil {
return err
}

View File

@@ -807,7 +807,11 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
}
for i, route := range peer.Routes {
peer.Routes[i] = anonymizeRoute(a, route)
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
}
@@ -843,21 +847,12 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
}
for i, route := range overview.Routes {
overview.Routes[i] = anonymizeRoute(a, route)
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}

View File

@@ -7,9 +7,6 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util"
@@ -56,10 +53,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
srv, err := sig.NewServer(otel.Meter(""))
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
sigProto.RegisterSignalExchangeServer(s, sig.NewServer())
go func() {
if err := s.Serve(lis); err != nil {
panic(err)
@@ -76,7 +70,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
if err != nil {
t.Fatal(err)
}
@@ -87,13 +81,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -108,7 +102,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
}
func startClientDaemon(
t *testing.T, ctx context.Context, _, configPath string,
t *testing.T, ctx context.Context, managementURL, configPath string,
) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0")

View File

@@ -7,13 +7,11 @@ import (
"net/netip"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -42,12 +40,8 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring")
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -143,10 +137,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
@@ -247,10 +237,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.NetworkMonitor = &networkMonitor
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
var loginErr error
var loginResp *proto.LoginResponse

View File

@@ -1,30 +0,0 @@
package errors
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}
func FormatErrorOrNil(err *multierror.Error) error {
if err != nil {
err.ErrorFormat = formatError
}
return err.ErrorOrNil()
}

View File

@@ -30,3 +30,9 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
}
return fm, nil
}
// Returns true if the current firewall implementation supports IPv6.
// Currently false for anything non-linux.
func SupportsIPv6() bool {
return false
}

View File

@@ -70,6 +70,8 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
return nil, errUsp
}
// Note for devs: When adding IPv6 support to userspace bind, the implementation of AllowNetbird() has to be
// adjusted accordingly.
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
@@ -83,6 +85,12 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
return fm, nil
}
// Returns true if the current firewall implementation supports IPv6.
// Currently true if the firewall is nftables.
func SupportsIPv6() bool {
return check() == NFTABLES
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType {
useIPTABLES := false

View File

@@ -6,6 +6,7 @@ import "github.com/netbirdio/netbird/iface"
type IFaceMapper interface {
Name() string
Address() iface.WGAddress
Address6() *iface.WGAddress
IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error
}

View File

@@ -24,6 +24,14 @@ type Manager struct {
router *routerManager
}
func (m *Manager) ResetV6Firewall() error {
return nil
}
func (m *Manager) V6Active() bool {
return false
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string

View File

@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil
}
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil {
return err
}
@@ -101,7 +101,6 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
}
delete(i.rules, ruleKey)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
@@ -318,13 +317,6 @@ func (i *routerManager) createChain(table, newChain string) error {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
@@ -334,30 +326,6 @@ func (i *routerManager) createChain(table, newChain string) error {
return nil
}
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}

View File

@@ -73,6 +73,9 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
if testCase.IsV6 {
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
@@ -154,6 +157,9 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
if testCase.IsV6 {
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
}
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouterManager(context.TODO(), iptablesClient)

View File

@@ -76,6 +76,13 @@ type Manager interface {
// RemoveRoutingRules removes a routing firewall rule
RemoveRoutingRules(pair RouterPair) error
// ResetV6Firewall makes changes to the firewall to adapt to the IP address changes.
// It is expected that after calling this method ApplyFiltering will be called to re-add the firewall rules.
ResetV6Firewall() error
// V6Active returns whether IPv6 rules should/may be created by upper layers.
V6Active() bool
// Reset firewall to the default state
Reset() error

File diff suppressed because it is too large Load Diff

View File

@@ -36,17 +36,25 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
wgIface: wgIface,
}
workTable, err := m.createWorkTable()
workTable, err := m.createWorkTable(nftables.TableFamilyIPv4)
if err != nil {
return nil, err
}
m.router, err = newRouter(context, workTable)
var workTable6 *nftables.Table
if wgIface.Address6() != nil {
workTable6, err = m.createWorkTable(nftables.TableFamilyIPv6)
if err != nil {
return nil, err
}
}
m.router, err = newRouter(context, workTable, workTable6)
if err != nil {
return nil, err
}
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
m.aclManager, err = newAclManager(workTable, workTable6, wgIface, m.router.RouteingFwChainName())
if err != nil {
return nil, err
}
@@ -54,6 +62,54 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return m, nil
}
// Resets the IPv6 Firewall Table to adapt to changes in IP addresses
func (m *Manager) ResetV6Firewall() error {
// First, prepare reset by deleting all currently active rules.
workTable6, err := m.aclManager.PrepareV6Reset()
if err != nil {
return err
}
// Depending on whether we now have an IPv6 address, we now either have to create/empty an IPv6 table, or delete it.
if m.wgIface.Address6() != nil {
if workTable6 != nil {
m.rConn.FlushTable(workTable6)
} else {
workTable6, err = m.createWorkTable(nftables.TableFamilyIPv6)
if err != nil {
return err
}
}
} else {
m.rConn.DelTable(workTable6)
workTable6 = nil
}
err = m.rConn.Flush()
if err != nil {
return err
}
// Restore routing rules.
err = m.router.RestoreAfterV6Reset(workTable6)
if err != nil {
return err
}
// Restore basic firewall chains (needs to happen after routes because chains from router must exist).
// Does not restore rules (will be done later during the update, when UpdateFiltering will be called at some point)
err = m.aclManager.ReinitAfterV6Reset(workTable6)
if err != nil {
return err
}
return m.rConn.Flush()
}
func (m *Manager) V6Active() bool {
return m.aclManager.v6Active
}
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
@@ -72,7 +128,7 @@ func (m *Manager) AddFiltering(
defer m.mutex.Unlock()
rawIP := ip.To4()
if rawIP == nil {
if rawIP == nil && m.wgIface.Address6() == nil {
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
@@ -95,7 +151,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddRoutingRules(pair)
return m.router.InsertRoutingRules(pair)
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
@@ -114,6 +170,8 @@ func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Note for devs: When adding IPv6 support to uspfilter, the implementation of createDefaultAllowRules()
// must be adjusted to include IPv6 rules.
err := m.aclManager.createDefaultAllowRules()
if err != nil {
return fmt.Errorf("failed to create default allow rules: %v", err)
@@ -211,8 +269,8 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush()
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
func (m *Manager) createWorkTable(tableFamily nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(tableFamily)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
@@ -223,7 +281,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: tableFamily})
err = m.rConn.Flush()
return table, err
}

View File

@@ -19,8 +19,9 @@ import (
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() iface.WGAddress
NameFunc func() string
AddressFunc func() iface.WGAddress
Address6Func func() *iface.WGAddress
}
func (i *iFaceMock) Name() string {
@@ -37,6 +38,13 @@ func (i *iFaceMock) Address() iface.WGAddress {
panic("AddressFunc is not set")
}
func (i *iFaceMock) Address6() *iface.WGAddress {
if i.Address6Func != nil {
return i.Address6Func()
}
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
@@ -53,6 +61,7 @@ func TestNftablesManager(t *testing.T) {
},
}
},
Address6Func: func() *iface.WGAddress { return nil },
}
// just check on the local interface
@@ -99,11 +108,9 @@ func TestNftablesManager(t *testing.T) {
Register: 1,
Data: ifname("lo"),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Register: 1,
@@ -152,6 +159,370 @@ func TestNftablesManager(t *testing.T) {
require.NoError(t, err, "failed to reset")
}
func TestNftablesManager6Disabled(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
Address6Func: func() *iface.WGAddress { return nil },
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
defer func() {
err = manager.Reset()
require.NoError(t, err, "failed to reset")
time.Sleep(time.Second)
}()
ip := net.ParseIP("2001:db8::fedc:ba09:8765:4321")
testClient := &nftables.Conn{}
_, err = manager.AddFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.Error(t, err, "IPv6 rule should not be added when IPv6 is disabled")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 0, "expected no rules")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
}
func TestNftablesManager6(t *testing.T) {
if !iface.SupportsIPv6() {
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
}
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
Address6Func: func() *iface.WGAddress {
return &iface.WGAddress{
IP: net.ParseIP("2001:db8::0123:4567:890a:bcde"),
Network: &net.IPNet{
IP: net.ParseIP("2001:db8::"),
Mask: net.CIDRMask(64, 128),
},
}
},
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
defer func() {
err = manager.Reset()
require.NoError(t, err, "failed to reset")
time.Sleep(time.Second)
}()
require.True(t, manager.V6Active(), "IPv6 is not active even though it should be.")
ip := net.ParseIP("2001:db8::fedc:ba09:8765:4321")
testClient := &nftables.Conn{}
rule, err := manager.AddFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 1, "expected 1 rules")
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname("lo"),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{0, 53},
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
for _, r := range rule {
err = manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 0, "expected 0 rules after deletion")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
}
func TestNftablesManagerAddressReset6(t *testing.T) {
if !iface.SupportsIPv6() {
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
}
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
Address6Func: func() *iface.WGAddress {
return &iface.WGAddress{
IP: net.ParseIP("2001:db8::0123:4567:890a:bcde"),
Network: &net.IPNet{
IP: net.ParseIP("2001:db8::"),
Mask: net.CIDRMask(64, 128),
},
}
},
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
defer func() {
err = manager.Reset()
require.NoError(t, err, "failed to reset")
time.Sleep(time.Second)
}()
require.True(t, manager.V6Active(), "IPv6 is not active even though it should be.")
ip := net.ParseIP("2001:db8::fedc:ba09:8765:4321")
testClient := &nftables.Conn{}
_, err = manager.AddFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 1, "expected 1 rules")
mock.Address6Func = func() *iface.WGAddress {
return nil
}
err = manager.ResetV6Firewall()
require.NoError(t, err, "failed to reset IPv6 firewall")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
require.False(t, manager.V6Active(), "IPv6 is active even though it shouldn't be.")
tables, err := testClient.ListTablesOfFamily(nftables.TableFamilyIPv6)
require.NoError(t, err, "failed to list IPv6 tables")
for _, table := range tables {
if table.Name == tableName {
t.Errorf("When IPv6 is disabled, the netbird table should not exist.")
}
}
mock.Address6Func = func() *iface.WGAddress {
return &iface.WGAddress{
IP: net.ParseIP("2001:db8::0123:4567:890a:bcdf"),
Network: &net.IPNet{
IP: net.ParseIP("2001:db8::"),
Mask: net.CIDRMask(64, 128),
},
}
}
err = manager.ResetV6Firewall()
require.NoError(t, err, "failed to reset IPv6 firewall")
require.True(t, manager.V6Active(), "IPv6 is not active even though it should be.")
rule, err := manager.AddFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 1, "expected 1 rule")
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname("lo"),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{0, 53},
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
for _, r := range rule {
err = manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.aclManager.workTable6, manager.aclManager.chainInputRules6)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 0, "expected 0 rules after deletion")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
}
func TestNFtablesCreatePerformance(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
@@ -166,6 +537,16 @@ func TestNFtablesCreatePerformance(t *testing.T) {
},
}
},
Address6Func: func() *iface.WGAddress {
v6addr, v6net, _ := net.ParseCIDR("fd00:1234:dead:beef::1/64")
return &iface.WGAddress{
IP: v6addr,
Network: &net.IPNet{
IP: v6net.IP,
Mask: v6net.Mask,
},
}
},
}
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {

View File

@@ -22,13 +22,12 @@ const (
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
zeroXor = binaryutil.NativeEndian.PutUint32(0)
zeroXor6 = []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}
exprCounterAccept = []expr.Any{
&expr.Counter{},
@@ -41,48 +40,69 @@ var (
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
workTable6 *nftables.Table
filterTable *nftables.Table
filterTable6 *nftables.Table
chains map[string]*nftables.Chain
chains6 map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
rules map[string]*nftables.Rule
rules6 map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
isDefaultFwdRulesEnabled6 bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
func newRouter(parentCtx context.Context, workTable *nftables.Table, workTable6 *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
workTable6: workTable6,
chains: make(map[string]*nftables.Chain),
chains6: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
rules6: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
r.filterTable, r.filterTable6, err = r.loadFilterTables()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
log.Warnf("table 'filter' not found for forward rules for one of the supported address families-")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
err = r.cleanUpDefaultForwardRules(false)
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
err = r.cleanUpDefaultForwardRules(true)
if err != nil {
log.Errorf("failed to clean up rules from IPv6 FORWARD chain: %s", err)
}
err = r.createContainers(false)
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
if r.workTable6 != nil {
err = r.createContainers(true)
if err != nil {
log.Errorf("failed to create v6 containers for route: %s", err)
}
}
return r, err
}
@@ -92,59 +112,99 @@ func (r *router) RouteingFwChainName() string {
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
err := r.cleanUpDefaultForwardRules(false)
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
err = r.cleanUpDefaultForwardRules(true)
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
func (r *router) RestoreAfterV6Reset(newWorktable6 *nftables.Table) error {
r.workTable6 = newWorktable6
if newWorktable6 != nil {
err := r.cleanUpDefaultForwardRules(true)
if err != nil {
log.Errorf("failed to clean up rules from IPv6 FORWARD chain: %s", err)
}
err = r.createContainers(true)
if err != nil {
return err
}
for name, rule := range r.rules6 {
rule = &nftables.Rule{
Table: r.workTable6,
Chain: r.chains6[rule.Chain.Name],
Exprs: rule.Exprs,
UserData: rule.UserData,
}
r.rules6[name] = r.conn.AddRule(rule)
}
}
return r.conn.Flush()
}
func (r *router) loadFilterTables() (*nftables.Table, *nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
return nil, nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
var table4 *nftables.Table = nil
for _, table := range tables {
if table.Name == "filter" {
return table, nil
table4 = table
break
}
}
return nil, errFilterTableNotFound
var table6 *nftables.Table = nil
tables, err = r.conn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return nil, nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
table6 = table
break
}
}
err = nil
if table4 == nil || table6 == nil {
err = errFilterTableNotFound
}
return table4, table6, err
}
func (r *router) createContainers() error {
func (r *router) createContainers(forV6 bool) error {
workTable := r.workTable
chainStorage := r.chains
if forV6 {
workTable = r.workTable6
chainStorage = r.chains6
}
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
chainStorage[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
Table: workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
chainStorage[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Table: workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
err := r.refreshRulesMap(forV6)
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
@@ -156,34 +216,44 @@ func (r *router) createContainers() error {
return nil
}
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
parsedIp, _, _ := net.ParseCIDR(pair.Source)
if parsedIp.To4() == nil && r.workTable6 == nil {
return fmt.Errorf("nftables: attempted to add IPv6 routing rule even though IPv6 is not enabled for this host")
}
err := r.refreshRulesMap(parsedIp.To4() == nil)
if err != nil {
return err
}
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
filterTable := r.filterTable
if parsedIp.To4() == nil {
filterTable = r.filterTable6
}
if filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
@@ -195,8 +265,8 @@ func (r *router) AddRoutingRules(pair manager.RouterPair) error {
return nil
}
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -209,7 +279,13 @@ func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPai
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
parsedIp, _, _ := net.ParseCIDR(pair.Source)
rules := r.rules
if parsedIp.To4() == nil {
rules = r.rules6
}
_, exists := rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
@@ -217,18 +293,35 @@ func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPai
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
table, chain := r.workTable, r.chains[chainName]
if parsedIp.To4() == nil {
table, chain = r.workTable6, r.chains6[chainName]
}
newRule := r.conn.InsertRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expression,
UserData: []byte(ruleKey),
})
if parsedIp.To4() == nil {
r.rules[ruleKey] = newRule
} else {
r.rules6[ruleKey] = newRule
}
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
table := r.filterTable
parsedIp, _, _ := net.ParseCIDR(sourceNetwork)
if parsedIp.To4() == nil {
dst = generateCIDRMatcherExpressions(false, "::/0")
table = r.filterTable6
}
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
@@ -236,10 +329,10 @@ func (r *router) acceptForwardRule(sourceNetwork string) {
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Table: table,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
@@ -251,6 +344,9 @@ func (r *router) acceptForwardRule(sourceNetwork string) {
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
if parsedIp.To4() == nil {
src = generateCIDRMatcherExpressions(true, "::/0")
}
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
@@ -258,10 +354,10 @@ func (r *router) acceptForwardRule(sourceNetwork string) {
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Table: table,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
@@ -270,12 +366,21 @@ func (r *router) acceptForwardRule(sourceNetwork string) {
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
if parsedIp.To4() == nil {
r.isDefaultFwdRulesEnabled6 = true
} else {
r.isDefaultFwdRulesEnabled = true
}
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
parsedIp, _, _ := net.ParseCIDR(pair.Source)
if parsedIp.To4() == nil && r.workTable6 == nil {
return fmt.Errorf("nftables: attempted to remove IPv6 routing rule even though IPv6 is not enabled for this host")
}
err := r.refreshRulesMap(parsedIp.To4() == nil)
if err != nil {
return err
}
@@ -300,8 +405,12 @@ func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
rulesList := r.rules
if parsedIp.To4() == nil {
rulesList = r.rules6
}
if len(rulesList) == 0 {
err := r.cleanUpDefaultForwardRules(parsedIp.To4() == nil)
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
@@ -319,7 +428,13 @@ func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
parsedIp, _, _ := net.ParseCIDR(pair.Source)
rules := r.rules
if parsedIp.To4() == nil {
rules = r.rules6
}
rule, found := rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
@@ -333,49 +448,68 @@ func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
delete(rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
func (r *router) refreshRulesMap(forV6 bool) error {
chainList := r.chains
if forV6 {
chainList = r.chains6
}
for _, chain := range chainList {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
if forV6 {
r.rules6[string(rule.UserData)] = rule
} else {
r.rules[string(rule.UserData)] = rule
}
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
func (r *router) cleanUpDefaultForwardRules(forV6 bool) error {
tableFamily := nftables.TableFamilyIPv4
filterTable := r.filterTable
if forV6 {
tableFamily = nftables.TableFamilyIPv6
filterTable = r.filterTable6
}
if filterTable == nil {
if forV6 {
r.isDefaultFwdRulesEnabled6 = false
} else {
r.isDefaultFwdRulesEnabled = false
}
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
chains, err := r.conn.ListChainsOfTableFamily(tableFamily)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
if chain.Table.Name != filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
rules, err = r.conn.GetRules(filterTable, chain)
if err != nil {
return err
}
@@ -389,7 +523,12 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
}
}
r.isDefaultFwdRulesEnabled = false
if forV6 {
r.isDefaultFwdRulesEnabled6 = false
} else {
r.isDefaultFwdRulesEnabled = false
}
return r.conn.Flush()
}
@@ -405,6 +544,18 @@ func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
} else {
offSet = 16 // dst offset
}
addrLen := uint32(4)
zeroXor := zeroXor
if ip.To4() == nil {
if source {
offSet = 8 // src offset
} else {
offSet = 24 // dst offset
}
addrLen = 16
zeroXor = zeroXor6
}
return []expr.Any{
// fetch src add
@@ -412,13 +563,13 @@ func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
Len: addrLen,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Len: addrLen,
Mask: network.Mask,
Xor: zeroXor,
},

View File

@@ -4,6 +4,7 @@ package nftables
import (
"context"
"github.com/netbirdio/netbird/iface"
"testing"
"github.com/coreos/go-iptables/iptables"
@@ -29,16 +30,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
table, table6, err := createWorkTables()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
defer deleteWorkTables()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
if testCase.IsV6 && table6 == nil {
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
}
manager, err := newRouter(context.TODO(), table, table6)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
@@ -47,7 +51,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
err = manager.AddRoutingRules(testCase.InputPair)
err = manager.InsertRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()
@@ -58,8 +62,13 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
chains := manager.chains
if testCase.IsV6 {
chains = manager.chains6
}
found := 0
for _, chain := range manager.chains {
for _, chain := range chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
@@ -75,7 +84,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if testCase.InputPair.Masquerade {
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
for _, chain := range chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
@@ -94,7 +103,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
found = 0
for _, chain := range manager.chains {
for _, chain := range chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
@@ -110,7 +119,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
for _, chain := range chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
@@ -131,16 +140,19 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
table, table6, err := createWorkTables()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
defer deleteWorkTables()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
if testCase.IsV6 && table6 == nil {
t.Skip("Environment does not support IPv6, skipping IPv6 test...")
}
manager, err := newRouter(context.TODO(), table, table6)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
@@ -150,11 +162,18 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
chains := manager.chains
workTable := table
if testCase.IsV6 {
chains = manager.chains6
workTable = table6
}
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Table: workTable,
Chain: chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
@@ -163,8 +182,8 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Table: workTable,
Chain: chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
@@ -175,8 +194,8 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Table: workTable,
Chain: chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
@@ -185,8 +204,8 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Table: workTable,
Chain: chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
@@ -199,7 +218,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
for _, chain := range chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
@@ -238,30 +257,39 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
return err == nil
}
func createWorkTable() (*nftables.Table, error) {
func createWorkTables() (*nftables.Table, *nftables.Table, error) {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
return nil, nil, err
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, err
return nil, nil, err
}
for _, t := range tables {
tables6, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return nil, nil, err
}
for _, t := range append(tables, tables6...) {
if t.Name == tableName {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
var table6 *nftables.Table
if iface.SupportsIPv6() {
table6 = sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6})
}
err = sConn.Flush()
return table, err
return table, table6, err
}
func deleteWorkTable() {
func deleteWorkTables() {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return
@@ -272,6 +300,12 @@ func deleteWorkTable() {
return
}
tables6, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return
}
tables = append(tables, tables6...)
for _, t := range tables {
if t.Name == tableName {
sConn.DelTable(t)

View File

@@ -8,6 +8,7 @@ var (
InsertRuleTestCases = []struct {
Name string
InputPair firewall.RouterPair
IsV6 bool
}{
{
Name: "Insert Forwarding IPV4 Rule",
@@ -27,12 +28,32 @@ var (
Masquerade: true,
},
},
{
Name: "Insert Forwarding IPV6 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "2001:db8:0123:4567::1/128",
Destination: "2001:db8:0123:abcd::/64",
Masquerade: false,
},
IsV6: true,
},
{
Name: "Insert Forwarding And Nat IPV6 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "2001:db8:0123:4567::1/128",
Destination: "2001:db8:0123:abcd::/64",
Masquerade: true,
},
IsV6: true,
},
}
RemoveRuleTestCases = []struct {
Name string
InputPair firewall.RouterPair
IpVersion string
IsV6 bool
}{
{
Name: "Remove Forwarding And Nat IPV4 Rules",
@@ -43,5 +64,15 @@ var (
Masquerade: true,
},
},
{
Name: "Remove Forwarding And Nat IPV6 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "2001:db8:0123:4567::1/128",
Destination: "2001:db8:0123:abcd::/64",
Masquerade: true,
},
IsV6: true,
},
}
)

View File

@@ -24,6 +24,7 @@ var (
type IFaceMapper interface {
SetFilter(iface.PacketFilter) error
Address() iface.WGAddress
Address6() *iface.WGAddress
}
// RuleSet is a set of rules grouped by a string key
@@ -69,6 +70,14 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
return mgr, nil
}
func (m *Manager) ResetV6Firewall() error {
return nil
}
func (m *Manager) V6Active() bool {
return false
}
func create(iface IFaceMapper) (*Manager, error) {
m := &Manager{
decoders: sync.Pool{

View File

@@ -33,6 +33,10 @@ func (i *IFaceMock) Address() iface.WGAddress {
return i.AddressFunc()
}
func (i *IFaceMock) Address6() *iface.WGAddress {
return nil
}
func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },

View File

@@ -16,9 +16,10 @@ import (
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// Manager is a ACL rules manager
// Manager is an ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap)
ResetV6Acl() error
}
// DefaultManager uses firewall manager to handle
@@ -26,16 +27,36 @@ type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
rulesPairs map[string][]firewall.Rule
rulesPairs6 map[string][]firewall.Rule
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
firewall: fm,
rulesPairs: make(map[string][]firewall.Rule),
firewall: fm,
rulesPairs: make(map[string][]firewall.Rule),
rulesPairs6: make(map[string][]firewall.Rule),
}
}
func (d *DefaultManager) ResetV6Acl() error {
for _, rules := range d.rulesPairs6 {
for _, r := range rules {
err := d.firewall.DeleteRule(r)
if err != nil {
return err
}
}
}
err := d.firewall.ResetV6Firewall()
if err != nil {
return err
}
d.rulesPairs6 = make(map[string][]firewall.Rule)
return nil
}
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
//
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
@@ -83,6 +104,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
PeerIP6: "::",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_TCP,
@@ -97,12 +119,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
rules = append(rules,
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
PeerIP6: "::",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
},
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
PeerIP6: "::",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
@@ -111,6 +135,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
}
newRulePairs := make(map[string][]firewall.Rule)
newRulePairs6 := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules {
@@ -123,7 +148,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
}
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
pairID, rulePair, rulePair6, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
d.rollBack(newRulePairs)
@@ -132,6 +157,8 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
if len(rules) > 0 {
d.rulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
d.rulesPairs6[pairID] = rulePair6
newRulePairs6[pairID] = rulePair6
}
}
@@ -146,59 +173,104 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
delete(d.rulesPairs, pairID)
}
}
for pairID, rules := range d.rulesPairs6 {
if _, ok := newRulePairs6[pairID]; !ok {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete firewall rule: %v", err)
continue
}
}
delete(d.rulesPairs6, pairID)
}
}
d.rulesPairs = newRulePairs
d.rulesPairs6 = newRulePairs6
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (string, []firewall.Rule, error) {
) (string, []firewall.Rule, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP)
if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
return "", nil, nil, fmt.Errorf("invalid IP address, skipping firewall rule")
}
var ip6 *net.IP = nil
if d.firewall.V6Active() && r.PeerIP6 != "" {
ip6tmp := net.ParseIP(r.PeerIP6)
if ip6tmp == nil {
return "", nil, nil, fmt.Errorf("invalid IP address, skipping firewall rule")
}
ip6 = &ip6tmp
}
protocol, err := convertToFirewallProtocol(r.Protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
return "", nil, nil, fmt.Errorf("skipping firewall rule: %s", err)
}
action, err := convertFirewallAction(r.Action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
return "", nil, nil, fmt.Errorf("skipping firewall rule: %s", err)
}
var port *firewall.Port
if r.Port != "" {
value, err := strconv.Atoi(r.Port)
if err != nil {
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
return "", nil, nil, fmt.Errorf("invalid port, skipping firewall rule")
}
port = &firewall.Port{
Values: []int{value},
}
}
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
var rules []firewall.Rule
var rules6 []firewall.Rule
ruleID := d.getRuleID(ip, ip6, protocol, int(r.Direction), port, action, "")
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
rules = rulesPair
}
if rulesPair6, ok := d.rulesPairs6[ruleID]; d.firewall.V6Active() && ok && ip6 != nil {
rules6 = rulesPair6
}
var rules []firewall.Rule
switch r.Direction {
case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.FirewallRule_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
if rules == nil {
switch r.Direction {
case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.FirewallRule_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default:
return "", nil, nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
}
if err != nil {
return "", nil, err
return "", nil, nil, err
}
return ruleID, rules, nil
if d.firewall.V6Active() && ip6 != nil && rules6 == nil {
switch r.Direction {
case mgmProto.FirewallRule_IN:
rules6, err = d.addInRules(*ip6, protocol, port, action, ipsetName, "")
case mgmProto.FirewallRule_OUT:
rules6, err = d.addOutRules(*ip6, protocol, port, action, ipsetName, "")
default:
return "", nil, nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
}
if err != nil && err.Error() != "failed to add firewall rule: attempted to configure filtering for IPv6 address even though IPv6 is not active" {
return "", rules, nil, err
}
return ruleID, rules, rules6, nil
}
func (d *DefaultManager) addInRules(
@@ -226,8 +298,9 @@ func (d *DefaultManager) addInRules(
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
return append(rules, rule...), nil
return rules, nil
}
func (d *DefaultManager) addOutRules(
@@ -255,20 +328,26 @@ func (d *DefaultManager) addOutRules(
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
return append(rules, rule...), nil
return rules, nil
}
// getRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getRuleID(
ip net.IP,
ip6 *net.IP,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
comment string,
) string {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
ip6Str := ""
if ip6 != nil {
ip6Str = ip6.String()
}
idStr := ip.String() + ip6Str + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
if port != nil {
idStr += port.String()
}
@@ -321,6 +400,8 @@ func (d *DefaultManager) squashAcceptRules(
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
// I don't _think_ that IPv6 is relevant here, as any optimization that has r.PeerIP6 == "::" should also
// implicitly have r.PeerIP == "0.0.0.0".
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return
@@ -364,6 +445,7 @@ func (d *DefaultManager) squashAcceptRules(
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
PeerIP6: "::",
Direction: direction,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: protocol,

View File

@@ -19,6 +19,7 @@ func TestDefaultManager(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
PeerIP6: "2001:db8::fedc:ba09:8765:0001",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_TCP,
@@ -26,6 +27,7 @@ func TestDefaultManager(t *testing.T) {
},
{
PeerIP: "10.93.0.2",
PeerIP6: "2001:db8::fedc:ba09:8765:0002",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_DROP,
Protocol: mgmProto.FirewallRule_UDP,
@@ -50,6 +52,14 @@ func TestDefaultManager(t *testing.T) {
IP: ip,
Network: network,
}).AnyTimes()
ip6, network6, err := net.ParseCIDR("2001:db8::fedc:ba09:8765:4321/64")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Address6().Return(&iface.WGAddress{
IP: ip6,
Network: network6,
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
@@ -83,6 +93,7 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules,
&mgmProto.FirewallRule{
PeerIP: "10.93.0.3",
PeerIP6: "2001:db8::fedc:ba09:8765:0003",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_DROP,
Protocol: mgmProto.FirewallRule_ICMP,
@@ -343,6 +354,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
IP: ip,
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().Address6().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)

View File

@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/client/internal/acl (interfaces: IFaceMapper)
// Source: ./client/firewall/iface.go
// Package mocks is a generated GoMock package.
package mocks
@@ -48,6 +48,20 @@ func (mr *MockIFaceMapperMockRecorder) Address() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Address", reflect.TypeOf((*MockIFaceMapper)(nil).Address))
}
// Address6 mocks base method.
func (m *MockIFaceMapper) Address6() *iface.WGAddress {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Address6")
ret0, _ := ret[0].(*iface.WGAddress)
return ret0
}
// Address6 indicates an expected call of Address6.
func (mr *MockIFaceMapperMockRecorder) Address6() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Address6", reflect.TypeOf((*MockIFaceMapper)(nil).Address6))
}
// IsUserspaceBind mocks base method.
func (m *MockIFaceMapper) IsUserspaceBind() bool {
m.ctrl.T.Helper()

View File

@@ -6,16 +6,13 @@ import (
"net/url"
"os"
"reflect"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
@@ -56,7 +53,6 @@ type ConfigInput struct {
NetworkMonitor *bool
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
}
// Config Configuration type
@@ -68,7 +64,7 @@ type Config struct {
AdminURL *url.URL
WgIface string
WgPort int
NetworkMonitor *bool
NetworkMonitor bool
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
@@ -99,9 +95,6 @@ type Config struct {
// DisableAutoConnect determines whether the client should not start with the service
// it's set to false by default due to backwards compatibility
DisableAutoConnect bool
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -311,21 +304,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
if input.NetworkMonitor != nil && *input.NetworkMonitor != config.NetworkMonitor {
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
config.NetworkMonitor = input.NetworkMonitor
config.NetworkMonitor = *input.NetworkMonitor
updated = true
}
if config.NetworkMonitor == nil {
// enable network monitoring by default on windows and darwin clients
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
enabled := true
config.NetworkMonitor = &enabled
updated = true
}
}
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
log.Infof("updating custom DNS address %#v (old value %#v)",
string(input.CustomDNSAddress), config.CustomDNSAddress)
@@ -373,18 +357,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
config.DNSRouteInterval = *input.DNSRouteInterval
updated = true
} else if config.DNSRouteInterval == 0 {
config.DNSRouteInterval = dynamic.DefaultInterval
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
updated = true
}
return updated, nil
}

View File

@@ -252,10 +252,8 @@ func (c *ConnectClient) run(
return wrapErr(err)
}
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
c.engineMutex.Unlock()
err = c.engine.Start()
@@ -309,25 +307,21 @@ func (c *ConnectClient) Engine() *Engine {
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
}
engineConf := &EngineConfig{
WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address,
WgAddr6: peerConfig.Address6,
IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,
WgPort: config.WgPort,
NetworkMonitor: nm,
NetworkMonitor: config.NetworkMonitor,
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
}
if config.PreSharedKey != "" {

View File

@@ -1,6 +0,0 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
)

View File

@@ -1,8 +0,0 @@
//go:build !android
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns
@@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
if checkStub() {
return systemdManager, nil
} else {
@@ -116,10 +116,16 @@ func getOSDNSManagerType() (osManagerType, error) {
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
var value string
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
if err == nil {
if value == systemdDbusResolvConfModeForeign {
return systemdManager, nil
}
}
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
}
return resolvConfManager, nil
}
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -485,7 +485,11 @@ func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord
}
func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
if ns.IP.Is4() {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
} else {
return fmt.Sprintf("[%s]:%d", ns.IP.String(), ns.Port)
}
}
// upstreamCallbacks returns two functions, the first one is used to deactivate

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -38,9 +38,12 @@ func (w *mocWGIface) Address() iface.WGAddress {
Network: network,
}
}
func (w *mocWGIface) ToInterface() *net.Interface {
panic("implement me")
func (w *mocWGIface) Address6() *iface.WGAddress {
ip, network, _ := net.ParseCIDR("fd00:1234:dead:beef::/64")
return &iface.WGAddress{
IP: ip,
Network: network,
}
}
func (w *mocWGIface) GetFilter() iface.PacketFilter {
@@ -265,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), fmt.Sprintf("fd00:1234:dead:beef::%d/128", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil {
t.Fatal(err)
}
@@ -343,7 +346,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", "", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
@@ -599,7 +602,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
}
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
wgIFace, err := createWgInterfaceWithBind(t, false)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
@@ -625,7 +628,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
wgIFace, err := createWgInterfaceWithBind(t, false)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
@@ -717,7 +720,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
func TestDNSPermanent_matchOnly(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
wgIFace, err := createWgInterfaceWithBind(t, false)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
@@ -788,7 +791,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
func createWgInterfaceWithBind(t *testing.T, enableV6 bool) (*iface.WGIface, error) {
t.Helper()
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
@@ -801,7 +804,11 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
v6Addr := ""
if enableV6 {
v6Addr = "fd00:1234:dead:beef::1/128"
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", v6Addr, 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err

View File

@@ -1,20 +0,0 @@
package dns
import (
"errors"
"fmt"
)
var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
}
func isSystemdResolvedRunning() bool {
return false
}
func isSystemdResolveConfMode() bool {
return false
}

View File

@@ -242,25 +242,3 @@ func getSystemdDbusProperty(property string, store any) error {
return v.Store(store)
}
func isSystemdResolvedRunning() bool {
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
}
func isSystemdResolveConfMode() bool {
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
return false
}
var value string
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
return false
}
if value == systemdDbusResolvConfModeForeign {
return true
}
return false
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns
@@ -14,6 +14,11 @@ import (
log "github.com/sirupsen/logrus"
)
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)
func CheckUncleanShutdown(wgIface string) error {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) {

View File

@@ -78,11 +78,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
select {
case <-u.ctx.Done():

View File

@@ -2,17 +2,12 @@
package dns
import (
"net"
"github.com/netbirdio/netbird/iface"
)
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper

View File

@@ -10,7 +10,6 @@ import (
"net/netip"
"reflect"
"runtime"
"slices"
"strings"
"sync"
"time"
@@ -29,21 +28,17 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -63,7 +58,8 @@ type EngineConfig struct {
WgIfaceName string
// WgAddr is a Wireguard local address (Netbird Network IP)
WgAddr string
WgAddr string
WgAddr6 string
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
WgPrivateKey wgtypes.Key
@@ -94,8 +90,6 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
DNSRouteInterval time.Duration
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -107,8 +101,8 @@ type Engine struct {
// peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
beforePeerHook peer.BeforeAddPeerHookFunc
afterPeerHook peer.AfterRemovePeerHookFunc
// rpManager is a Rosenpass manager
rpManager *rosenpass.Manager
@@ -161,9 +155,6 @@ type Engine struct {
wgProbe *Probe
wgConnWorker sync.WaitGroup
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
}
// Peer is an instance of the Connection Peer
@@ -181,7 +172,6 @@ func NewEngine(
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
return NewEngineWithProbes(
clientCtx,
@@ -195,7 +185,6 @@ func NewEngine(
nil,
nil,
nil,
checks,
)
}
@@ -212,7 +201,6 @@ func NewEngineWithProbes(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
checks []*mgmProto.Checks,
) *Engine {
return &Engine{
@@ -233,7 +221,6 @@ func NewEngineWithProbes(
signalProbe: signalProbe,
relayProbe: relayProbe,
wgProbe: wgProbe,
checks: checks,
}
}
@@ -282,6 +269,8 @@ func (e *Engine) Start() error {
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
wgIface, err := e.newWgIface()
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
@@ -289,9 +278,6 @@ func (e *Engine) Start() error {
}
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive {
@@ -316,7 +302,7 @@ func (e *Engine) Start() error {
}
e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
@@ -542,10 +528,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// todo update signal
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap())
@@ -553,27 +535,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
}
return nil
}
// updateChecksIfNew updates checks if there are changes and sync new meta with management
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
// if checks are equal, we skip the update
if isChecksEqual(e.checks, checks) {
return nil
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil
}
@@ -589,8 +551,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} else {
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on Windows is not supported")
return nil
}
// start SSH server if it wasn't running
@@ -642,6 +604,32 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
log.Infof("updated peer address from %s to %s", oldAddr, conf.Address)
}
if e.wgInterface.Address6() == nil && conf.Address6 != "" ||
e.wgInterface.Address6() != nil && e.wgInterface.Address6().String() != conf.Address6 {
oldAddr := "none"
if e.wgInterface.Address6() != nil {
oldAddr = e.wgInterface.Address6().String()
}
newAddr := "none"
if conf.Address6 != "" {
newAddr = conf.Address6
}
log.Debugf("updating peer IPv6 address from %s to %s", oldAddr, newAddr)
err := e.wgInterface.UpdateAddr6(conf.Address6)
if err != nil {
return err
}
e.config.WgAddr6 = conf.Address6
err = e.acl.ResetV6Acl()
if err != nil {
return err
}
e.routeManager.ResetV6Routes()
log.Infof("updated peer IPv6 address from %s to %s", oldAddr, conf.Address6)
}
if conf.GetSshConfig() != nil {
err := e.updateSSH(conf.GetSshConfig())
if err != nil {
@@ -651,6 +639,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
IP: e.config.WgAddr,
IP6: e.config.WgAddr6,
PubKey: e.config.WgPrivateKey.PublicKey().String(),
KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: conf.GetFqdn(),
@@ -663,14 +652,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
go func() {
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
err := e.mgmClient.Sync(e.ctx, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -737,20 +719,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -792,6 +760,19 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
@@ -819,24 +800,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes {
var prefix netip.Prefix
if len(protoRoute.Domains) == 0 {
var err error
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
continue
}
}
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID),
Network: prefix,
Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric),
Masquerade: protoRoute.Masquerade,
KeepRoute: protoRoute.KeepRoute,
}
routes = append(routes, convertedRoute)
}
@@ -1039,6 +1011,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(),
RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(),
}
@@ -1101,6 +1074,8 @@ func (e *Engine) receiveSignalEvents() {
return err
}
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1123,6 +1098,8 @@ func (e *Engine) receiveSignalEvents() {
return err
}
conn.RegisterProtoSupportMeta(msg.GetBody().GetFeaturesSupported())
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1260,8 +1237,7 @@ func (e *Engine) close() {
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
info := system.GetInfo(e.ctx)
netMap, err := e.mgmClient.GetNetworkMap(info)
netMap, err := e.mgmClient.GetNetworkMap()
if err != nil {
return nil, nil, err
}
@@ -1290,7 +1266,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
default:
}
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgAddr6, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
}
func (e *Engine) wgInterfaceCreate() (err error) {
@@ -1465,15 +1441,6 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
}
func (e *Engine) restartEngine() {
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
@@ -1482,54 +1449,17 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New()
go func() {
var mu sync.Mutex
var debounceTimer *time.Timer
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
debounceTimer.Stop()
log.Infof("Network monitor detected network change, restarting engine")
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
}
}()
}
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}
return false, netip.Prefix{}, nil
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}

View File

@@ -17,7 +17,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@@ -58,9 +57,9 @@ var (
)
func TestEngine_SSH(t *testing.T) {
// todo resolve test execution on freebsd
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
t.Skip("skipping TestEngine_SSH")
if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH on Windows")
}
key, err := wgtypes.GeneratePrivateKey()
@@ -78,7 +77,7 @@ func TestEngine_SSH(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -174,7 +173,7 @@ func TestEngine_SSH(t *testing.T) {
t.Fatal(err)
}
// time.Sleep(250 * time.Millisecond)
//time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
@@ -212,16 +211,16 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", "", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
if err != nil {
t.Fatal(err)
}
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
@@ -394,7 +393,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
@@ -409,7 +408,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -566,15 +565,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgAddr6: "",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgAddr6, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
@@ -736,16 +736,17 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgAddr6: "",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgAddr6, 33100, key.String(), iface.DefaultMTU, newNet, nil)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
@@ -811,13 +812,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
sigServer, signalAddr, err := startSignal()
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, dir)
mgmtServer, mgmtAddr, err := startManagement(dir)
if err != nil {
t.Fatal(err)
return
@@ -1009,14 +1010,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
e.ctx = ctx
return e, err
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
func startSignal() (*grpc.Server, string, error) {
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
@@ -1024,9 +1023,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
go func() {
if err = s.Serve(lis); err != nil {
@@ -1037,9 +1034,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
t.Helper()
func startManagement(dataDir string) (*grpc.Server, string, error) {
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
@@ -1056,25 +1051,23 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
store, _, err := server.NewTestStoreFromJson(config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@@ -5,6 +5,8 @@ package networkmonitor
import (
"context"
"fmt"
"net"
"net/netip"
"syscall"
"unsafe"
@@ -12,10 +14,10 @@ import (
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err)
@@ -45,6 +47,24 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle interface state changes
case unix.RTM_IFINFO:
ifinfo, err := parseInterfaceMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing interface message: %v", err)
continue
}
if msg.Flags&unix.IFF_UP != 0 {
continue
}
if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) {
continue
}
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
go callback()
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
@@ -66,7 +86,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback()
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback()
}
@@ -76,7 +96,25 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.InterfaceMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return msg, nil
}
func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
@@ -91,5 +129,5 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
return routemanager.MsgToRoute(msg)
}

View File

@@ -6,13 +6,14 @@ import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime/debug"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager"
)
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
@@ -28,22 +29,23 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
var nexthop4, nexthop6 netip.Addr
var intf4, intf6 *net.Interface
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified())
nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name)
}
// continue if either route was found
@@ -63,7 +65,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
}
}()
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}

View File

@@ -6,22 +6,27 @@ import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error {
if intfv4 == nil && intfv6 == nil {
return errors.New("no interfaces available")
}
linkChan := make(chan netlink.LinkUpdate)
done := make(chan struct{})
defer close(done)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
return fmt.Errorf("subscribe to link updates: %v", err)
}
routeChan := make(chan netlink.RouteUpdate)
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
return fmt.Errorf("subscribe to route updates: %v", err)
@@ -33,6 +38,25 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case <-ctx.Done():
return ErrStopped
// handle interface state changes
case update := <-linkChan:
if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) {
continue
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
go callback()
return nil
case syscall.RTM_NEWLINK:
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
go callback()
return nil
}
}
// handle route changes
case route := <-routeChan:
// default route and main table
@@ -46,7 +70,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
go callback()
return nil
case syscall.RTM_DELROUTE:
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil

View File

@@ -5,12 +5,11 @@ import (
"fmt"
"net"
"net/netip"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager"
)
const (
@@ -26,16 +25,20 @@ const (
const interval = 10 * time.Second
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
var neighborv4, neighborv6 *systemops.Neighbor
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
var neighborv4, neighborv6 *routemanager.Neighbor
{
initialNeighbors, err := getNeighbors()
if err != nil {
return fmt.Errorf("get neighbors: %w", err)
}
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
if n, ok := initialNeighbors[nexthopv4]; ok {
neighborv4 = &n
}
if n, ok := initialNeighbors[nexthopv6]; ok {
neighborv6 = &n
}
}
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
@@ -47,7 +50,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case <-ctx.Done():
return ErrStopped
case <-ticker.C:
if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
go callback()
return nil
}
@@ -55,21 +58,13 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
}
}
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
if n, ok := initialNeighbors[nexthop.IP]; ok &&
n.State != unreachable &&
n.State != incomplete &&
n.State != tbd {
return &n
}
return nil
}
func changed(
nexthopv4 systemops.Nexthop,
neighborv4 *systemops.Neighbor,
nexthopv6 systemops.Nexthop,
neighborv6 *systemops.Neighbor,
nexthopv4 netip.Addr,
intfv4 *net.Interface,
neighborv4 *routemanager.Neighbor,
nexthopv6 netip.Addr,
intfv6 *net.Interface,
neighborv6 *routemanager.Neighbor,
) bool {
neighbors, err := getNeighbors()
if err != nil {
@@ -86,7 +81,7 @@ func changed(
return false
}
if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) {
return true
}
@@ -94,74 +89,44 @@ func changed(
}
// routeChanged checks if the default routes still point to our nexthop/interface
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
if !nexthop.IP.IsValid() {
func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool {
if !nexthop.IsValid() {
return false
}
unspec := getUnspecifiedPrefix(nexthop.IP)
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
var unspec netip.Prefix
if nexthop.Is6() {
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
} else {
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
if !foundMatchingRoute {
logRouteChange(nexthop.IP, intf)
if r, ok := routes[unspec]; ok {
if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 {
intf := "<nil>"
if r.Interface != nil {
intf = r.Interface.Name
}
log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf)
return true
}
} else {
log.Infof("network monitor: default route is gone")
return true
}
return false
}
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
if ip.Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string
foundMatchingRoute := false
for _, r := range routes {
if r.Destination == unspec {
routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo)
}
}
}
return defaultRoutes, foundMatchingRoute
}
func formatRouteInfo(r systemops.Route) string {
newIntf := "<nil>"
if r.Interface != nil {
newIntf = r.Interface.Name
}
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
}
func logRouteChange(ip netip.Addr, intf *net.Interface) {
oldIntf := "<nil>"
if intf != nil {
oldIntf = intf.Name
}
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
}
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool {
if neighbor == nil {
return false
}
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
if n, ok := neighbors[nexthop.IP]; ok {
if n.State == unreachable || n.State == incomplete {
if n, ok := neighbors[nexthop]; ok {
if n.State != reachable && n.State != permanent {
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
return true
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
@@ -185,13 +150,13 @@ func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, ne
return false
}
func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
entries, err := systemops.GetNeighbors()
func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
entries, err := routemanager.GetNeighbors()
if err != nil {
return nil, fmt.Errorf("get neighbors: %w", err)
}
neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries))
for _, entry := range entries {
neighbours[entry.IPAddress] = entry
}
@@ -199,13 +164,18 @@ func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
return neighbours, nil
}
func getRoutes() ([]systemops.Route, error) {
entries, err := systemops.GetRoutes()
func getRoutes() (map[netip.Prefix]routemanager.Route, error) {
entries, err := routemanager.GetRoutes()
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
return entries, nil
routes := make(map[netip.Prefix]routemanager.Route, len(entries))
for _, entry := range entries {
routes[entry.Destination] = entry
}
return routes, nil
}
func stateFromInt(state uint8) string {

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
@@ -19,6 +20,7 @@ import (
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
@@ -68,6 +70,9 @@ type ConnConfig struct {
NATExternalIPs []string
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
UserspaceBind bool
// RosenpassPubKey is this peer's Rosenpass public key
RosenpassPubKey []byte
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
@@ -98,6 +103,9 @@ type IceCredentials struct {
Pwd string
}
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
type Conn struct {
config ConnConfig
mu sync.Mutex
@@ -127,13 +135,30 @@ type Conn struct {
wgProxyFactory *wgproxy.Factory
wgProxy wgproxy.Proxy
remoteModeCh chan ModeMessage
meta meta
adapter iface.TunAdapter
iFaceDiscover stdnet.ExternalIFaceDiscover
sentExtraSrflx bool
remoteEndpoint *net.UDPAddr
remoteConn *ice.Conn
connID nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
beforeAddPeerHooks []BeforeAddPeerHookFunc
afterRemovePeerHooks []AfterRemovePeerHookFunc
}
// meta holds meta information about a connection
type meta struct {
protoSupport signal.FeaturesSupport
}
// ModeMessage represents a connection mode chosen by the peer
type ModeMessage struct {
// Direct indicates that it decided to use a direct connection
Direct bool
}
// GetConf returns the connection config
@@ -162,6 +187,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
statusRecorder: statusRecorder,
remoteModeCh: make(chan ModeMessage, 1),
wgProxyFactory: wgProxyFactory,
adapter: adapter,
iFaceDiscover: iFaceDiscover,
@@ -352,6 +378,8 @@ func (conn *Conn) Open(ctx context.Context) error {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
conn.remoteConn = remoteConn
// the ice connection has been established successfully so we are ready to start the proxy
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr)
@@ -376,11 +404,11 @@ func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
@@ -409,6 +437,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.remoteEndpoint = endpointUdpAddr
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
conn.connID = nbnet.GenerateConnID()
@@ -455,7 +484,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
log.Warnf("unable to save peer's state, got error: %v", err)
}
_, ipNet, err := net.ParseCIDR(conn.config.WgConfig.AllowedIps)
_, ipNet, err := net.ParseCIDR(strings.Split(conn.config.WgConfig.AllowedIps, ",")[0])
if err != nil {
return nil, err
}
@@ -594,39 +623,40 @@ func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) err
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
// and then signals them to the remote peer
func (conn *Conn) onICECandidate(candidate ice.Candidate) {
// nil means candidate gathering has been ended
if candidate == nil {
return
if candidate != nil {
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
log.Debugf("discovered local candidate %s", candidate.String())
go func() {
err := conn.signalCandidate(candidate)
if err != nil {
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
relatedAdd := candidate.RelatedAddress()
extraSrflx, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
if err != nil {
log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
err = conn.signalCandidate(extraSrflx)
if err != nil {
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
return
}
conn.sentExtraSrflx = true
}
}()
}
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
log.Debugf("discovered local candidate %s", candidate.String())
go func() {
err := conn.signalCandidate(candidate)
if err != nil {
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
}
}()
if !conn.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
conn.sentExtraSrflx = true
go func() {
err = conn.signalCandidate(extraSrflx)
if err != nil {
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
}
}()
}
func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
@@ -761,6 +791,10 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
err := conn.agent.AddRemoteCandidate(candidate)
if err != nil {
log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
@@ -773,21 +807,36 @@ func (conn *Conn) GetKey() string {
return conn.config.Key
}
func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
// RegisterProtoSupportMeta register supported proto message in the connection metadata
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
protoSupport := signal.ParseFeaturesSupported(support)
conn.meta.protoSupport = protoSupport
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network)
}
}
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes {
// default route is
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
}
return false
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
}

View File

@@ -36,7 +36,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -51,7 +51,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -88,7 +88,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -124,7 +124,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -154,7 +154,7 @@ func TestConn_Status(t *testing.T) {
}
func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()

View File

@@ -2,17 +2,14 @@ package peer
import (
"errors"
"net/netip"
"sync"
"time"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
)
// State contains the latest state of a peer
@@ -40,25 +37,25 @@ type State struct {
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
defer s.Mux.Unlock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
defer s.Mux.Unlock()
s.routes = routes
s.Mux.Unlock()
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
defer s.Mux.Unlock()
delete(s.routes, network)
s.Mux.Unlock()
}
// GetRoutes return routes map
@@ -71,6 +68,7 @@ func (s *State) GetRoutes() map[string]struct{} {
// LocalPeerState contains the latest state of the local peer
type LocalPeerState struct {
IP string
IP6 string
PubKey string
KernelInterface bool
FQDN string
@@ -120,23 +118,22 @@ type FullStatus struct {
// Status holds a state of peers, signal, management connections and relays
type Status struct {
mux sync.Mutex
peers map[string]State
changeNotify map[string]chan struct{}
signalState bool
signalError error
managementState bool
managementError error
relayStates []relay.ProbeResult
localPeer LocalPeerState
offlinePeers []State
mgmAddress string
signalAddress string
notifier *notifier
rosenpassEnabled bool
rosenpassPermissive bool
nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain][]netip.Prefix
mux sync.Mutex
peers map[string]State
changeNotify map[string]chan struct{}
signalState bool
signalError error
managementState bool
managementError error
relayStates []relay.ProbeResult
localPeer LocalPeerState
offlinePeers []State
mgmAddress string
signalAddress string
notifier *notifier
rosenpassEnabled bool
rosenpassPermissive bool
nsGroupStates []NSGroupState
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -147,12 +144,11 @@ type Status struct {
// NewRecorder returns a new Status instance
func NewRecorder(mgmAddress string) *Status {
return &Status{
peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}),
offlinePeers: make([]State, 0),
notifier: newNotifier(),
mgmAddress: mgmAddress,
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}),
offlinePeers: make([]State, 0),
notifier: newNotifier(),
mgmAddress: mgmAddress,
}
}
@@ -193,7 +189,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
state, ok := d.peers[peerPubKey]
if !ok {
return State{}, iface.ErrPeerNotFound
return State{}, errors.New("peer not found")
}
return state, nil
}
@@ -434,18 +430,6 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates
}
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
d.mux.Lock()
defer d.mux.Unlock()
d.resolvedDomainsStates[domain] = prefixes
}
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
d.mux.Lock()
defer d.mux.Unlock()
delete(d.resolvedDomainsStates, domain)
}
func (d *Status) GetRosenpassState() RosenpassState {
return RosenpassState{
d.rosenpassEnabled,
@@ -510,12 +494,6 @@ func (d *Status) GetDNSStates() []NSGroupState {
return d.nsGroupStates
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)
}
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()

View File

@@ -3,20 +3,19 @@ package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
const minRangeBits = 7
type routerPeerStatus struct {
connected bool
relayed bool
@@ -29,42 +28,34 @@ type routesUpdate struct {
routes []*route.Route
}
// RouteHandler defines the interface for handling routes
type RouteHandler interface {
String() string
AddRoute(ctx context.Context) error
RemoveRoute() error
AddAllowedIPs(peerKey string) error
RemoveAllowedIPs() error
}
type clientNetwork struct {
ctx context.Context
cancel context.CancelFunc
stop context.CancelFunc
statusRecorder *peer.Status
wgInterface *iface.WGIface
routes map[route.ID]*route.Route
routeUpdate chan routesUpdate
peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{}
currentChosen *route.Route
handler RouteHandler
chosenRoute *route.Route
chosenIP *net.IP
network netip.Prefix
updateSerial uint64
}
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
ctx: ctx,
cancel: cancel,
stop: cancel,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
network: network,
}
return client
}
@@ -96,8 +87,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored.
// * Latency: Routes with lower latency are prioritized.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
// * Latency: Routes with lower latency are prioritized.
//
// It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
@@ -106,8 +97,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
currScore := float64(0)
currID := route.ID("")
if c.currentChosen != nil {
currID = c.currentChosen.ID
if c.chosenRoute != nil {
currID = c.chosenRoute.ID
}
for _, r := range c.routes {
@@ -161,18 +152,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
peers = append(peers, r.Peer)
}
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
case chosen != currID:
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
if currScore != 0 && currScore+0.01 > chosenScore {
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
log.Debugf("keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
return currID
}
var p string
if rt := c.routes[chosen]; rt != nil {
p = rt.Peer
}
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, p, chosenScore, c.network)
}
return chosen
@@ -206,103 +197,109 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
}
}
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
c.removeStateRoute()
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil {
return fmt.Errorf("get peer state: %v", err)
}
if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if state.ConnStatus != peer.StatusConnected {
return nil
}
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
if err != nil {
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
return nil
}
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.currentChosen == nil {
return nil
}
if c.chosenRoute != nil {
// TODO IPv6 (pass wgInterface)
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
}
var merr *multierror.Error
if err := c.removeRouteFromWireguardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return fmt.Errorf("remove route: %v", err)
}
}
if err := c.handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
return nil
}
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
routerPeerStatuses := c.getRouterPeerStatuses()
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system
if newChosenID == "" {
if chosen == "" {
if err := c.removeRouteFromPeerAndSystem(); err != nil {
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
return fmt.Errorf("remove route from peer and system: %v", err)
}
c.currentChosen = nil
c.chosenRoute = nil
return nil
}
// If the chosen route is the same as the current route, do nothing
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
c.currentChosen.IsEqual(c.routes[newChosenID]) {
return nil
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
if c.chosenRoute.IsEqual(c.routes[chosen]) {
return nil
}
}
if c.currentChosen == nil {
// If they were not previously assigned to another peer, add routes to the system first
if err := c.handler.AddRoute(c.ctx); err != nil {
return fmt.Errorf("add route: %w", err)
if c.chosenRoute != nil {
// If a previous route exists, remove it from the peer
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return fmt.Errorf("remove route from peer: %v", err)
}
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireguardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
// TODO recheck IPv6
gwAddr := c.wgInterface.Address().IP
c.chosenIP = &gwAddr
if c.network.Addr().Is6() {
if c.wgInterface.Address6() == nil {
return fmt.Errorf("Could not assign IPv6 route %s for peer %s because no IPv6 address is assigned",
c.network.String(), c.wgInterface.Address().IP.String())
}
c.chosenIP = &c.wgInterface.Address6().IP
}
// otherwise add the route to the system
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.chosenIP.String(), err)
}
}
c.currentChosen = c.routes[newChosenID]
c.chosenRoute = c.routes[chosen]
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
state.AddRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
c.addStateRoute()
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
return nil
}
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() {
c.routeUpdate <- update
@@ -333,23 +330,24 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
for {
select {
case <-c.ctx.Done():
log.Debugf("Stopping watcher for network [%v]", c.handler)
if err := c.removeRouteFromPeerAndSystem(); err != nil {
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
log.Debugf("stopping watcher for network %s", c.network)
err := c.removeRouteFromPeerAndSystem()
if err != nil {
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
}
return
case <-c.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
}
case update := <-c.routeUpdate:
if update.updateSerial < c.updateSerial {
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
log.Warnf("Received a routes update with smaller serial number, ignoring it")
continue
}
log.Debugf("Received a new client network route update for [%v]", c.handler)
log.Debugf("Received a new client network route update for %s", c.network)
c.handleUpdate(update)
@@ -357,7 +355,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
}
c.startPeersStatusChangeWatcher()
@@ -365,9 +363,14 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
if rt.IsDynamic() {
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
func (c *clientNetwork) getAsInterface() *net.Interface {
intf, err := net.InterfaceByName(c.wgInterface.Name())
if err != nil {
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
intf = &net.Interface{
Name: c.wgInterface.Name(),
}
}
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
return intf
}

View File

@@ -5,7 +5,6 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route"
)
@@ -341,9 +340,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
// create new clientNetwork
client := &clientNetwork{
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes,
currentChosen: currentRoute,
network: netip.MustParsePrefix("192.168.0.0/24"),
routes: tc.existingRoutes,
chosenRoute: currentRoute,
}
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)

View File

@@ -1,378 +0,0 @@
package dynamic
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const (
DefaultInterval = time.Minute
minInterval = 2 * time.Second
failureInterval = 5 * time.Second
addAllowedIP = "add allowed IP %s: %w"
)
type domainMap map[domain.Domain][]netip.Prefix
type resolveResult struct {
domain domain.Domain
prefix netip.Prefix
err error
}
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
interval time.Duration
dynamicDomains domainMap
mu sync.Mutex
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
}
func NewRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
}
}
func (r *Route) String() string {
s, err := r.route.Domains.String()
if err != nil {
return r.route.Domains.PunycodeString()
}
return s
}
func (r *Route) AddRoute(ctx context.Context) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
ctx, r.cancel = context.WithCancel(ctx)
go r.startResolver(ctx)
return nil
}
// RemoveRoute will stop the dynamic resolver and remove all dynamic routes.
// It doesn't touch allowed IPs, these should be removed separately and before calling this method.
func (r *Route) RemoveRoute() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
var merr *multierror.Error
for domain, prefixes := range r.dynamicDomains {
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
}
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
r.statusRecorder.DeleteResolvedDomainsStates(domain)
}
r.dynamicDomains = domainMap{}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) AddAllowedIPs(peerKey string) error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for domain, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
}
r.currentPeerKey = peerKey
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) RemoveAllowedIPs() error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for _, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
}
r.currentPeerKey = ""
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) startResolver(ctx context.Context) {
log.Debugf("Starting dynamic route resolver for domains [%v]", r)
interval := r.interval
if interval < minInterval {
interval = minInterval
log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval)
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
if interval > failureInterval {
ticker.Reset(failureInterval)
}
}
for {
select {
case <-ctx.Done():
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
return
case <-ticker.C:
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
// Use a lower ticker interval if the update fails
if interval > failureInterval {
ticker.Reset(failureInterval)
}
} else if interval > failureInterval {
// Reset to the original interval if the update succeeds
ticker.Reset(interval)
}
}
}
}
func (r *Route) update(ctx context.Context) error {
if resolved, err := r.resolveDomains(); err != nil {
return fmt.Errorf("resolve domains: %w", err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
return fmt.Errorf("update dynamic routes: %w", err)
}
return nil
}
func (r *Route) resolveDomains() (domainMap, error) {
results := make(chan resolveResult)
go r.resolve(results)
resolved := domainMap{}
var merr *multierror.Error
for result := range results {
if result.err != nil {
merr = multierror.Append(merr, result.err)
} else {
resolved[result.domain] = append(resolved[result.domain], result.prefix)
}
}
return resolved, nberrors.FormatErrorOrNil(merr)
}
func (r *Route) resolve(results chan resolveResult) {
var wg sync.WaitGroup
for _, d := range r.route.Domains {
wg.Add(1)
go func(domain domain.Domain) {
defer wg.Done()
ips, err := net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)}
return
}
results <- resolveResult{domain: domain, prefix: prefix}
}
}(d)
}
wg.Wait()
close(results)
}
func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error {
r.mu.Lock()
defer r.mu.Unlock()
if ctx.Err() != nil {
return ctx.Err()
}
var merr *multierror.Error
for domain, newPrefixes := range newDomains {
oldPrefixes := r.dynamicDomains[domain]
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
addedPrefixes, err := r.addRoutes(domain, toAdd)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(addedPrefixes) > 0 {
log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", "))
}
removedPrefixes, err := r.removeRoutes(toRemove)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(removedPrefixes) > 0 {
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", "))
}
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) {
var addedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
continue
}
if r.currentPeerKey != "" {
if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
addedPrefixes = append(addedPrefixes, prefix)
}
return addedPrefixes, merr.ErrorOrNil()
}
func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
if r.route.KeepRoute {
return nil, nil
}
var removedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
if r.currentPeerKey != "" {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
removedPrefixes = append(removedPrefixes, prefix)
}
return removedPrefixes, merr.ErrorOrNil()
}
func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
return fmt.Errorf(addAllowedIP, prefix, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
return nil
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {
prefixSet[prefix] = false
}
for _, prefix := range newPrefixes {
if _, exists := prefixSet[prefix]; exists {
prefixSet[prefix] = true
} else {
toAdd = append(toAdd, prefix)
}
}
for prefix, inUse := range prefixSet {
if !inUse {
toRemove = append(toRemove, prefix)
}
}
return
}
func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix {
prefixSet := make(map[netip.Prefix]struct{})
for _, prefix := range oldPrefixes {
prefixSet[prefix] = struct{}{}
}
for _, prefix := range removedPrefixes {
delete(prefixSet, prefix)
}
for _, prefix := range addedPrefixes {
prefixSet[prefix] = struct{}{}
}
var combinedPrefixes []netip.Prefix
for prefix := range prefixSet {
combinedPrefixes = append(combinedPrefixes, prefix)
}
return combinedPrefixes
}

View File

@@ -2,23 +2,18 @@ package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
@@ -26,85 +21,51 @@ import (
"github.com/netbirdio/netbird/version"
)
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
// nolint:unused
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface
type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
ResetV6Routes()
EnableServerRouter(firewall firewall.Manager) error
Stop()
}
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier
}
func NewManager(
ctx context.Context,
pubKey string,
dnsRouteInterval time.Duration,
wgInterface *iface.WGIface,
statusRecorder *peer.Status,
initialRoutes []*route.Route,
) *DefaultManager {
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface)
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: newNotifier(),
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: newNotifier(),
}
dm.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ any) (any, error) {
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ any) error {
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
},
)
dm.allowedIPsRefCounter = refcounter.New(
func(prefix netip.Prefix, peerKey string) (string, error) {
// save peerKey to use it in the remove function
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
},
func(prefix netip.Prefix, peerKey string) error {
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
return err
}
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
}
return nil
},
)
if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr)
@@ -113,12 +74,12 @@ func NewManager(
}
// Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := m.sysOps.CleanupRouting(); err != nil {
if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -126,7 +87,7 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
signalAddress := m.statusRecorder.GetSignalState().URL
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err)
}
@@ -150,19 +111,8 @@ func (m *DefaultManager) Stop() {
m.serverRouter.cleanUp()
}
if m.routeRefCounter != nil {
if err := m.routeRefCounter.Flush(); err != nil {
log.Errorf("Error flushing route ref counter: %v", err)
}
}
if m.allowedIPsRefCounter != nil {
if err := m.allowedIPsRefCounter.Flush(); err != nil {
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
}
}
if !nbnet.CustomRoutingDisabled() {
if err := m.sysOps.CleanupRouting(); err != nil {
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
@@ -199,6 +149,18 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
}
}
// ResetV6Routes deletes all IPv6 routes (necessary if IPv6 address changes).
// It is expected that UpdateRoute is called afterwards to recreate the routing table.
func (m *DefaultManager) ResetV6Routes() {
for id, client := range m.clientNetworks {
if client.network.Addr().Is6() {
log.Debugf("stopping client network watcher due to IPv6 address change, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
}
// SetRouteChangeListener set RouteListener for route change notifier
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
m.notifier.setListener(listener)
@@ -236,7 +198,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
@@ -248,7 +210,7 @@ func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id)
client.cancel()
client.stop()
delete(m.clientNetworks, id)
}
}
@@ -261,7 +223,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
@@ -279,7 +241,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
ownNetworkIDs := make(map[route.HAUniqueID]bool)
for _, newRoute := range newRoutes {
haID := newRoute.GetHAUniqueID()
haID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[haID] = true
// only linux is supported for now
@@ -292,9 +254,9 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
}
for _, newRoute := range newRoutes {
haID := newRoute.GetHAUniqueID()
haID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[haID] {
if !isRouteSupported(newRoute) {
if !isPrefixSupported(newRoute.Network) {
continue
}
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
@@ -306,23 +268,23 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0, len(crMap))
rs := make([]*route.Route, 0)
for _, routes := range crMap {
rs = append(rs, routes...)
}
return rs
}
func isRouteSupported(route *route.Route) bool {
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
func isPrefixSupported(prefix netip.Prefix) bool {
if !nbnet.CustomRoutingDisabled() {
return true
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
// we skip this prefix management
if route.Network.Bits() <= vars.MinRangeBits {
if prefix.Bits() <= minRangeBits {
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
version.NetbirdVersion(), route.Network)
version.NetbirdVersion(), prefix)
return false
}
return true

View File

@@ -36,6 +36,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
serverRoutesExpected int
clientNetworkWatchersExpected int
clientNetworkWatchersExpectedAllowed int
isV6 bool
}{
{
name: "Should create 2 client networks",
@@ -65,6 +66,35 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputSerial: 1,
clientNetworkWatchersExpected: 2,
},
{
name: "Should create 2 client networks (IPv6)",
inputInitRoutes: []*route.Route{},
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 2,
isV6: true,
},
{
name: "Should Create 2 Server Routes",
inputRoutes: []*route.Route{
@@ -93,6 +123,34 @@ func TestManagerUpdateRoutes(t *testing.T) {
serverRoutesExpected: 2,
clientNetworkWatchersExpected: 0,
},
{
name: "Should Create 2 Server Routes (IPv6)",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: localPeerKey,
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
serverRoutesExpected: 2,
clientNetworkWatchersExpected: 0,
},
{
name: "Should Create 1 Route For Client And Server",
inputRoutes: []*route.Route{
@@ -121,6 +179,84 @@ func TestManagerUpdateRoutes(t *testing.T) {
serverRoutesExpected: 1,
clientNetworkWatchersExpected: 1,
},
{
name: "Should Create 1 Route For Client And Server (IPv6)",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
serverRoutesExpected: 1,
clientNetworkWatchersExpected: 1,
isV6: true,
},
{
name: "Should Create 1 Route For Client And Server for each IP version",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.9.9/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
serverRoutesExpected: 2,
clientNetworkWatchersExpected: 2,
isV6: true,
},
{
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router",
inputRoutes: []*route.Route{
@@ -150,6 +286,36 @@ func TestManagerUpdateRoutes(t *testing.T) {
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 1,
},
{
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router (IPv6)",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
removeSrvRouter: true,
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 1,
isV6: true,
},
{
name: "Should Create 1 HA Route and 1 Standalone",
inputRoutes: []*route.Route{
@@ -187,6 +353,44 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputSerial: 1,
clientNetworkWatchersExpected: 2,
},
{
name: "Should Create 1 HA Route and 1 Standalone (IPv6)",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeA",
Peer: remotePeerKey2,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "c",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::7890:abcd/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 2,
isV6: true,
},
{
name: "No Small Client Route Should Be Added",
inputRoutes: []*route.Route{
@@ -205,6 +409,25 @@ func TestManagerUpdateRoutes(t *testing.T) {
clientNetworkWatchersExpected: 0,
clientNetworkWatchersExpectedAllowed: 1,
},
{
name: "No Small Client Route Should Be Added (IPv6)",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("::/0"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 0,
clientNetworkWatchersExpectedAllowed: 1,
isV6: true,
},
{
name: "Remove 1 Client Route",
inputInitRoutes: []*route.Route{
@@ -244,6 +467,46 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputSerial: 1,
clientNetworkWatchersExpected: 1,
},
{
name: "Remove 1 Client Route (IPv6)",
inputInitRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 1,
isV6: true,
},
{
name: "Update Route to HA",
inputInitRoutes: []*route.Route{
@@ -293,6 +556,56 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputSerial: 1,
clientNetworkWatchersExpected: 1,
},
{
name: "Update Route to HA (IPv6)",
inputInitRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeA",
Peer: remotePeerKey2,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 1,
isV6: true,
},
{
name: "Remove Client Routes",
inputInitRoutes: []*route.Route{
@@ -321,6 +634,35 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputSerial: 1,
clientNetworkWatchersExpected: 0,
},
{
name: "Remove Client Routes (IPv6)",
inputInitRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputRoutes: []*route.Route{},
inputSerial: 1,
clientNetworkWatchersExpected: 0,
isV6: true,
},
{
name: "Remove All Routes",
inputInitRoutes: []*route.Route{
@@ -350,6 +692,36 @@ func TestManagerUpdateRoutes(t *testing.T) {
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0,
},
{
name: "Remove All Routes (IPv6)",
inputInitRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputRoutes: []*route.Route{},
inputSerial: 1,
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0,
isV6: true,
},
{
name: "HA server should not register routes from the same HA group",
inputRoutes: []*route.Route{
@@ -398,16 +770,74 @@ func TestManagerUpdateRoutes(t *testing.T) {
serverRoutesExpected: 2,
clientNetworkWatchersExpected: 1,
},
{
name: "HA server should not register routes from the same HA group (IPv6)",
inputRoutes: []*route.Route{
{
ID: "l1",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "l2",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("2001:db8::abcd:7890/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "r1",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "r2",
NetID: "routeC",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("2001:db8::abcd:789f/128"),
NetworkType: route.IPv6Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
serverRoutesExpected: 2,
clientNetworkWatchersExpected: 1,
isV6: true,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
v6Addr := ""
//goland:noinspection GoBoolExpressions
if !iface.SupportsIPv6() && testCase.isV6 {
t.Skip("Platform does not support IPv6, skipping IPv6 test...")
} else if testCase.isV6 {
v6Addr = "2001:db8::4242:4711/128"
}
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", v6Addr, 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
@@ -416,7 +846,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
_, _, err = routeManager.Init()
@@ -436,7 +866,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected
if testCase.clientNetworkWatchersExpectedAllowed != 0 {
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
}
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")

View File

@@ -6,10 +6,10 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
)
// MockManager is the mock instance of a route manager
@@ -20,7 +20,7 @@ type MockManager struct {
StopFunc func()
}
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
@@ -64,6 +64,10 @@ func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
panic("implement me")
}
func (m *MockManager) ResetV6Routes() {
panic("implement me")
}
// Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop() {
if m.StopFunc != nil {

View File

@@ -1,155 +0,0 @@
package refcounter
import (
"errors"
"fmt"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
var ErrIgnore = errors.New("ignore")
type Ref[O any] struct {
Count int
Out O
}
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
type Counter[I, O any] struct {
// refCountMap keeps track of the reference Ref for prefixes
refCountMap map[netip.Prefix]Ref[O]
refCountMu sync.Mutex
// idMap keeps track of the prefixes associated with an ID for removal
idMap map[string][]netip.Prefix
idMu sync.Mutex
add AddFunc[I, O]
remove RemoveFunc[I, O]
}
// New creates a new Counter instance
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
return &Counter[I, O]{
refCountMap: map[netip.Prefix]Ref[O]{},
idMap: map[string][]netip.Prefix{},
add: add,
remove: remove,
}
}
// Increment increments the reference count for the given prefix.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref := rm.refCountMap[prefix]
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
// Call AddFunc only if it's a new prefix
if ref.Count == 0 {
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
out, err := rm.add(prefix, in)
if errors.Is(err, ErrIgnore) {
return ref, nil
}
if err != nil {
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
}
ref.Out = out
}
ref.Count++
rm.refCountMap[prefix] = ref
return ref, nil
}
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
ref, err := rm.Increment(prefix, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
rm.idMap[id] = append(rm.idMap[id], prefix)
return ref, nil
}
// Decrement decrements the reference count for the given prefix.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref, ok := rm.refCountMap[prefix]
if !ok {
log.Tracef("No reference found for prefix %s", prefix)
return ref, nil
}
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
if ref.Count == 1 {
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
if err := rm.remove(prefix, ref.Out); err != nil {
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
}
delete(rm.refCountMap, prefix)
} else {
ref.Count--
rm.refCountMap[prefix] = ref
}
return ref, nil
}
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for _, prefix := range rm.idMap[id] {
if _, err := rm.Decrement(prefix); err != nil {
merr = multierror.Append(merr, err)
}
}
delete(rm.idMap, id)
return nberrors.FormatErrorOrNil(merr)
}
// Flush removes all references and calls RemoveFunc for each prefix.
func (rm *Counter[I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for prefix := range rm.refCountMap {
log.Tracef("Removing for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.remove(prefix, ref.Out); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]Ref[O]{}
rm.idMap = map[string][]netip.Prefix{}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,7 +0,0 @@
package refcounter
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
type RouteRefCounter = Counter[any, any]
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
type AllowedIPsRefCounter = Counter[string, string]

View File

@@ -0,0 +1,127 @@
//go:build !android && !ios
package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
type ref struct {
count int
nexthop netip.Addr
intf *net.Interface
}
type RouteManager struct {
// refCountMap keeps track of the reference ref for prefixes
refCountMap map[netip.Prefix]ref
// prefixMap keeps track of the prefixes associated with a connection ID for removal
prefixMap map[nbnet.ConnectionID][]netip.Prefix
addRoute AddRouteFunc
removeRoute RemoveRouteFunc
mutex sync.Mutex
}
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap
return &RouteManager{
refCountMap: map[netip.Prefix]ref{},
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
addRoute: addRoute,
removeRoute: removeRoute,
}
}
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
ref := rm.refCountMap[prefix]
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
// Add route to the system, only if it's a new prefix
if ref.count == 0 {
log.Debugf("Adding route for prefix %s", prefix)
nexthop, intf, err := rm.addRoute(prefix)
if errors.Is(err, ErrRouteNotFound) {
return nil
}
if errors.Is(err, ErrRouteNotAllowed) {
log.Debugf("Adding route for prefix %s: %s", prefix, err)
}
if err != nil {
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
}
ref.nexthop = nexthop
ref.intf = intf
}
ref.count++
rm.refCountMap[prefix] = ref
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
return nil
}
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
prefixes, ok := rm.prefixMap[connID]
if !ok {
log.Debugf("No prefixes found for connection ID %s", connID)
return nil
}
var result *multierror.Error
for _, prefix := range prefixes {
ref := rm.refCountMap[prefix]
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
if ref.count == 1 {
log.Debugf("Removing route for prefix %s", prefix)
// TODO: don't fail if the route is not found
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
continue
}
delete(rm.refCountMap, prefix)
} else {
ref.count--
rm.refCountMap[prefix] = ref
}
}
delete(rm.prefixMap, connID)
return result.ErrorOrNil()
}
// Flush removes all references and routes from the system
func (rm *RouteManager) Flush() error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
var result *multierror.Error
for prefix := range rm.refCountMap {
log.Debugf("Removing route for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]ref{}
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
return result.ErrorOrNil()
}

View File

@@ -12,7 +12,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
@@ -71,7 +70,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
}
if len(m.routes) > 0 {
err := systemops.EnableIPForwarding()
err := enableIPForwarding(m.wgInterface.Address6() != nil)
if err != nil {
return err
}
@@ -80,7 +79,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
return nil
}
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
func (m *defaultServerRouter) removeFromServerNetwork(rt *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("Not removing from server network because context is done")
@@ -88,28 +87,32 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
default:
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route)
routingAddress := m.wgInterface.Address().Masked().String()
if rt.NetworkType == route.IPv6Network {
if m.wgInterface.Address6() == nil {
return fmt.Errorf("attempted to add route for IPv6 even though device has no v6 address")
}
routingAddress = m.wgInterface.Address6().Masked().String()
}
routerPair, err := routeToRouterPair(routingAddress, rt)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.RemoveRoutingRules(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err)
return err
}
delete(m.routes, route.ID)
delete(m.routes, rt.ID)
state := m.statusRecorder.GetLocalPeerState()
delete(state.Routes, route.Network.String())
delete(state.Routes, rt.Network.String())
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
}
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
func (m *defaultServerRouter) addToServerNetwork(rt *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("Not adding to server network because context is done")
@@ -117,8 +120,15 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
default:
m.mux.Lock()
defer m.mux.Unlock()
routingAddress := m.wgInterface.Address().Masked().String()
if rt.NetworkType == route.IPv6Network {
if m.wgInterface.Address6() == nil {
return fmt.Errorf("attempted to add route for IPv6 even though device has no v6 address")
}
routingAddress = m.wgInterface.Address6().Masked().String()
}
routerPair, err := routeToRouterPair(route)
routerPair, err := routeToRouterPair(routingAddress, rt)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
@@ -128,19 +138,13 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
return fmt.Errorf("insert routing rules: %w", err)
}
m.routes[route.ID] = route
m.routes[rt.ID] = rt
state := m.statusRecorder.GetLocalPeerState()
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
state.Routes[routeStr] = struct{}{}
state.Routes[rt.Network.String()] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state)
return nil
@@ -151,10 +155,17 @@ func (m *defaultServerRouter) cleanUp() {
m.mux.Lock()
defer m.mux.Unlock()
for _, r := range m.routes {
routerPair, err := routeToRouterPair(r)
routingAddress := m.wgInterface.Address().Masked().String()
if r.NetworkType == route.IPv6Network {
if m.wgInterface.Address6() == nil {
log.Errorf("attempted to remove route for IPv6 even though device has no v6 address")
continue
}
routingAddress = m.wgInterface.Address6().Masked().String()
}
routerPair, err := routeToRouterPair(routingAddress, r)
if err != nil {
log.Errorf("Failed to convert route to router pair: %v", err)
continue
log.Errorf("parse prefix: %v", err)
}
err = m.firewall.RemoveRoutingRules(routerPair)
@@ -169,27 +180,15 @@ func (m *defaultServerRouter) cleanUp() {
m.statusRecorder.UpdateLocalPeerState(state)
}
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
// TODO: add ipv6
source := getDefaultPrefix(route.Network)
destination := route.Network.Masked().String()
if route.IsDynamic() {
// TODO: add ipv6
destination = "0.0.0.0/0"
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
parsed, err := netip.ParsePrefix(source)
if err != nil {
return firewall.RouterPair{}, err
}
return firewall.RouterPair{
ID: string(route.ID),
Source: source.String(),
Destination: destination,
Source: parsed.String(),
Destination: route.Network.Masked().String(),
Masquerade: route.Masquerade,
}, nil
}
func getDefaultPrefix(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}

View File

@@ -1,57 +0,0 @@
package static
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
}
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
}
}
// Route route methods
func (r *Route) String() string {
return r.route.Network.String()
}
func (r *Route) AddRoute(context.Context) error {
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
return err
}
func (r *Route) RemoveRoute() error {
_, err := r.routeRefCounter.Decrement(r.route.Network)
return err
}
func (r *Route) AddAllowedIPs(peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled",
r.route.Network,
ref.Out,
)
}
return nil
}
func (r *Route) RemoveAllowedIPs() error {
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
return err
}

View File

@@ -1,103 +0,0 @@
// go:build !android
package sysctl
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/iface"
)
const (
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
// Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := Set(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = Set(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := Set(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, nberrors.FormatErrorOrNil(result)
}
// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func Set(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
// Cleanup resets sysctl settings to their original values.
func Cleanup(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := Set(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -0,0 +1,424 @@
//go:build !android && !ios
package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRouteNotFound = errors.New("route not found")
var ErrRouteNotAllowed = errors.New("route not allowed")
// TODO: fix: for default our wg address now appears as the default gw
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
defaultGateway, _, err := GetNextHop(addr)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(defaultGateway) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
if defaultGateway.Is6() {
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
gatewayHop, intf, err := GetNextHop(defaultGateway)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
}
func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
r, err := netroute.New()
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err)
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
}
return addr.Unmap(), intf, nil
}
addr, err := ipToAddr(gateway, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
}
return addr, intf, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
if runtime.GOOS == "windows" {
addr = addr.WithZone(strconv.Itoa(intf.Index))
} else {
addr = addr.WithZone(intf.Name)
}
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
linkLocalPrefix, err := netip.ParsePrefix("fe80::/10")
if err != nil {
return false, err
}
if prefix.Addr().Is6() && linkLocalPrefix.Contains(prefix.Addr()) {
// The link local prefix is not explicitly part of the routing table, but should be considered as such.
return true, nil
}
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return netip.Addr{}, nil, ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := GetNextHop(addr)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
exitIntf := intf
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
exitIntf = initialIntf
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
}
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == defaultv6 {
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, netip.Addr{}, intf)
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
} else if prefix == defaultv6 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
return removeFromRouteTable(prefix, netip.Addr{}, intf)
}
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return nil, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return &prefix, nil
}
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {
nexthop, intf = initialNextHopV6, initialIntfV6
}
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
},
removeFromRouteTable,
)
return setupHooks(*routeManager, initAddresses)
}
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
if routeManager == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := routeManager.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := getPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := routeManager.RemoveRouteRef(connID); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return result.ErrorOrNil()
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}

View File

@@ -1,18 +0,0 @@
//go:build darwin || dragonfly || netbsd || openbsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
return false
}

View File

@@ -1,19 +0,0 @@
//go:build: freebsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
return false
}

View File

@@ -1,27 +0,0 @@
package systemops
import (
"net"
"net/netip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
type Nexthop struct {
IP netip.Addr
Intf *net.Interface
}
type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
}
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
return &SysOps{
wgInterface: wgInterface,
}
}

View File

@@ -1,473 +0,0 @@
//go:build !android && !ios
package systemops
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
refCounter := refcounter.New(
func(prefix netip.Prefix, _ any) (Nexthop, error) {
initialNexthop := initialNextHopV4
if prefix.Addr().Is6() {
initialNexthop = initialNextHopV6
}
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore
}
return nexthop, err
},
r.removeFromRouteTable,
)
r.refCounter = refCounter
return r.setupHooks(initAddresses)
}
func (r *SysOps) cleanupRefCounter() error {
if r.refCounter == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
// TODO: fix: for default our wg address now appears as the default gw
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return r.addToRouteTable(gatewayPrefix, nexthop)
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr)
if err != nil {
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
exitNextHop := Nexthop{
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP)
if err := r.addToRouteTable(prefix, exitNextHop); err != nil {
return Nexthop{}, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
return err
}
if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == vars.Defaultv6 {
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return r.addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
} else if prefix == vars.Defaultv6 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
}
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(result)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}
func GetNextHop(ip netip.Addr) (Nexthop, error) {
r, err := netroute.New()
if err != nil {
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err)
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if runtime.GOOS == "freebsd" {
return Nexthop{Intf: intf}, nil
}
if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
}
return Nexthop{
IP: addr,
Intf: intf,
}, nil
}
addr, err := ipToAddr(gateway, intf)
if err != nil {
return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err)
}
return Nexthop{
IP: addr,
Intf: intf,
}, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
zone := intf.Name
if runtime.GOOS == "windows" {
zone = strconv.Itoa(intf.Index)
}
log.Tracef("Adding zone %s to address %s", zone, addr)
addr = addr.WithZone(zone)
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting()
if err != nil {
if !errors.Is(err, ErrRoutingIsSeparate) {
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{}
}
return isVpnRoute(addr, vpnRoutes, localRoutes)
}
func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.Prefix) (bool, netip.Prefix) {
vpnPrefixMap := map[netip.Prefix]struct{}{}
for _, prefix := range vpnRoutes {
vpnPrefixMap[prefix] = struct{}{}
}
// remove vpnRoute duplicates
for _, prefix := range localRoutes {
delete(vpnPrefixMap, prefix)
}
var longestPrefix netip.Prefix
var isVpn bool
combinedRoutes := make([]netip.Prefix, len(vpnRoutes)+len(localRoutes))
copy(combinedRoutes, vpnRoutes)
copy(combinedRoutes[len(vpnRoutes):], localRoutes)
for _, prefix := range combinedRoutes {
// Ignore the default route, it has special handling
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
// Longest prefix match
if !longestPrefix.IsValid() || prefix.Bits() > longestPrefix.Bits() {
longestPrefix = prefix
_, isVpn = vpnPrefixMap[prefix]
}
}
}
if !longestPrefix.IsValid() {
// No route matched
return false, netip.Prefix{}
}
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}

View File

@@ -1,38 +0,0 @@
//go:build ios || android
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
return nil
}
func (r *SysOps) AddVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}

View File

@@ -1,28 +0,0 @@
//go:build !linux && !ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericAddVPNRoute(prefix, intf)
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericRemoveVPNRoute(prefix, intf)
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return getRoutesFromTable()
}

View File

@@ -0,0 +1,33 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@@ -1,6 +1,6 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package systemops
package routemanager
import (
"errors"
@@ -43,7 +43,8 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
}
if filterRoutesByFlags(m.Flags) {
if m.Flags&syscall.RTF_UP == 0 ||
m.Flags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
continue
}
@@ -92,7 +93,7 @@ func toNetIP(a route.Addr) netip.Addr {
case *route.Inet6Addr:
ip := netip.AddrFrom16(t.IP)
if t.ZoneID != 0 {
ip = ip.WithZone(strconv.Itoa(t.ZoneID))
ip.WithZone(strconv.Itoa(t.ZoneID))
}
return ip
default:
@@ -100,7 +101,6 @@ func toNetIP(a route.Addr) netip.Addr {
}
}
// ones returns the number of leading ones in the mask.
func ones(a route.Addr) (int, error) {
switch t := a.(type) {
case *route.Inet4Addr:
@@ -114,7 +114,6 @@ func ones(a route.Addr) (int, error) {
}
}
// MsgToRoute converts a route message to a Route.
func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]

View File

@@ -0,0 +1,57 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package routemanager
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/net/route"
)
func TestBits(t *testing.T) {
tests := []struct {
name string
addr route.Addr
want int
wantErr bool
}{
{
name: "IPv4 all ones",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
want: 32,
},
{
name: "IPv4 normal mask",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
want: 24,
},
{
name: "IPv6 all ones",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
want: 128,
},
{
name: "IPv6 normal mask",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
want: 64,
},
{
name: "Unsupported type",
addr: &route.LinkAddr{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ones(tt.addr)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}

View File

@@ -1,6 +1,6 @@
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
//go:build darwin && !ios
package systemops
package routemanager
import (
"fmt"
@@ -13,41 +13,43 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses)
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
func (r *SysOps) CleanupRouting() error {
return r.cleanupRefCounter()
func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("add", prefix, nexthop)
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("add", prefix, nexthop, intf)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("delete", prefix, nexthop)
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("delete", prefix, nexthop, intf)
}
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
inet := "-inet"
if prefix.Addr().Is6() {
inet = "-inet6"
}
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
if prefix.Addr().Is6() {
inet = "-inet6"
}
args := []string{"-n", action, inet, network}
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
} else if nexthop.Intf != nil {
args = append(args, "-interface", nexthop.Intf.Name)
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else if intf != nil {
args = append(args, "-interface", intf.Name)
}
if err := retryRouteCmd(args); err != nil {

View File

@@ -1,6 +1,6 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
//go:build !ios
package systemops
package routemanager
import (
"fmt"
@@ -13,7 +13,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/route"
)
var expectedVPNint = "utun100"
@@ -36,15 +35,13 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
@@ -60,7 +57,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
@@ -70,53 +67,6 @@ func TestConcurrentRoutes(t *testing.T) {
wg.Wait()
}
func TestBits(t *testing.T) {
tests := []struct {
name string
addr route.Addr
want int
wantErr bool
}{
{
name: "IPv4 all ones",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
want: 32,
},
{
name: "IPv4 normal mask",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
want: 24,
},
{
name: "IPv6 all ones",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
want: 128,
},
{
name: "IPv6 normal mask",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
want: 64,
},
{
name: "Unsupported type",
addr: &route.LinkAddr{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ones(tt.addr)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()

View File

@@ -0,0 +1,33 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func enableIPForwarding(includeV6 bool) error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@@ -1,6 +1,6 @@
//go:build !android
package systemops
package routemanager
import (
"bufio"
@@ -9,15 +9,16 @@ import (
"net"
"net/netip"
"os"
"strconv"
"strings"
"syscall"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -32,10 +33,16 @@ const (
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
@@ -75,7 +82,7 @@ func getSetupRules() []ruleParams {
}
}
// SetupRouting establishes the routing configuration for the VPN, including essential rules
// setupRouting establishes the routing configuration for the VPN, including essential rules
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
//
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
@@ -85,17 +92,17 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy() {
log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses)
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
originalValues, err := setupSysctl(wgIface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
@@ -104,7 +111,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
defer func() {
if err != nil {
if cleanErr := r.CleanupRouting(); cleanErr != nil {
if cleanErr := cleanupRouting(); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr)
}
}
@@ -116,7 +123,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true)
return r.setupRefCounter(initAddresses)
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
}
@@ -125,12 +132,12 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
return nil, nil, nil
}
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func (r *SysOps) CleanupRouting() error {
func cleanupRouting() error {
if isLegacy() {
return r.cleanupRefCounter()
return cleanupRoutingWithRouteManager(routeManager)
}
var result *multierror.Error
@@ -149,58 +156,46 @@ func (r *SysOps) CleanupRouting() error {
}
}
if err := sysctl.Cleanup(originalSysctl); err != nil {
if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return nberrors.FormatErrorOrNil(result)
return result.ErrorOrNil()
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return addRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return removeRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return r.genericAddVPNRoute(prefix, intf)
return genericAddVPNRoute(prefix, intf)
}
if sysctlFailed && (prefix == vars.Defaultv4 || prefix == vars.Defaultv6) {
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
if prefix == vars.Defaultv4 {
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add blackhole: %w", err)
}
}
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add route: %w", err)
}
return nil
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return r.genericRemoveVPNRoute(prefix, intf)
return genericRemoveVPNRoute(prefix, intf)
}
// TODO remove this once we have ipv6 support
if prefix == vars.Defaultv4 {
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove unreachable route: %w", err)
}
}
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
@@ -248,7 +243,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
}
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
@@ -261,7 +256,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
}
route.Dst = ipNet
if err := addNextHop(nexthop, route); err != nil {
if err := addNextHop(addr, intf, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
}
@@ -275,6 +270,9 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
// tableID specifies the routing table to which the unreachable route will be added.
// TODO should this be kept in for future use? If so, the linter needs to be told that this unreachable function should
//
// be kept
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
@@ -295,6 +293,9 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
return nil
}
// TODO should this be kept in for future use? If so, the linter needs to be told that this unreachable function should
//
// be kept
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
@@ -320,7 +321,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@@ -333,7 +334,7 @@ func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
Dst: ipNet,
}
if err := addNextHop(nexthop, route); err != nil {
if err := addNextHop(addr, intf, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
}
@@ -366,11 +367,17 @@ func flushRoutes(tableID, family int) error {
}
}
return nberrors.FormatErrorOrNil(result)
return result.ErrorOrNil()
}
func EnableIPForwarding() error {
_, err := sysctl.Set(ipv4ForwardingPath, 1, false)
func enableIPForwarding(includeV6 bool) error {
_, err := setSysctl(ipv4ForwardingPath, 1, false)
if err != nil {
return err
}
if includeV6 {
_, err = setSysctl(ipv4ForwardingPath, 1, false)
}
return err
}
@@ -474,19 +481,19 @@ func removeRule(params ruleParams) error {
}
// addNextHop adds the gateway and device to the route.
func addNextHop(nexthop Nexthop, route *netlink.Route) error {
if nexthop.Intf != nil {
route.LinkIndex = nexthop.Intf.Index
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
if intf != nil {
route.LinkIndex = intf.Index
}
if nexthop.IP.IsValid() {
route.Gw = nexthop.IP.AsSlice()
if addr.IsValid() {
route.Gw = addr.AsSlice()
// if zone is set, it means the gateway is a link-local address, so we set the link index
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
link, err := netlink.LinkByName(nexthop.IP.Zone())
if addr.Zone() != "" && intf == nil {
link, err := netlink.LinkByName(addr.Zone())
if err != nil {
return fmt.Errorf("get link by name for zone %s: %w", nexthop.IP.Zone(), err)
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
}
route.LinkIndex = link.Attrs().Index
}
@@ -502,9 +509,82 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6
}
func hasSeparateRouting() ([]netip.Prefix, error) {
if isLegacy() {
return getRoutesFromTable()
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
return nil, ErrRoutingIsSeparate
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}

View File

@@ -1,6 +1,6 @@
//go:build !android
package systemops
package routemanager
import (
"errors"
@@ -14,8 +14,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
var expectedVPNint = "wgtest0"
@@ -140,7 +138,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
if dstIPNet.String() == "0.0.0.0/0" {
var err error
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
if err != nil && !errors.Is(err, ErrRouteNotFound) {
t.Logf("Failed to fetch original gateway: %v", err)
}
@@ -195,7 +193,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
}
}
return nil, 0, vars.ErrRouteNotFound
return nil, 0, ErrRouteNotFound
}
func setupDummyInterfacesAndRoutes(t *testing.T) {

View File

@@ -0,0 +1,24 @@
//go:build !linux && !ios
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func enableIPForwarding(includeV6 bool) error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericAddVPNRoute(prefix, intf)
}
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericRemoveVPNRoute(prefix, intf)
}

View File

@@ -1,11 +1,13 @@
//go:build !android && !ios
package systemops
package routemanager
import (
"bytes"
"context"
"fmt"
"github.com/google/gopacket/routing"
"github.com/netbirdio/netbird/client/firewall"
"net"
"net/netip"
"os"
@@ -46,41 +48,55 @@ func TestAddRemoveRoutes(t *testing.T) {
shouldRouteToWireguard: false,
shouldBeRemoved: false,
},
{
name: "Should Add And Remove Route 2001:db8:1234:5678::/64",
prefix: netip.MustParsePrefix("2001:db8:1234:5678::/64"),
shouldRouteToWireguard: true,
shouldBeRemoved: true,
},
{
name: "Should Not Add Or Remove Route ::1/128",
prefix: netip.MustParsePrefix("::1/128"),
shouldRouteToWireguard: false,
shouldBeRemoved: false,
},
}
for n, testCase := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", testCase.name, " on freebsd")
}
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
v6Addr := ""
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
//goland:noinspection GoBoolExpressions
if (!iface.SupportsIPv6() || !firewall.SupportsIPv6() || !hasV6DefaultRoute || err != nil) && testCase.prefix.Addr().Is6() {
t.Skip("Platform does not support IPv6, skipping IPv6 test...")
} else if testCase.prefix.Addr().Is6() {
v6Addr = "2001:db8::4242:4711/128"
}
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", v6Addr, 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface)
_, _, err = r.SetupRouting(nil)
_, _, err = setupRouting(nil, wgInterface)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting())
assert.NoError(t, cleanupRouting())
})
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = r.AddVPNRoute(testCase.prefix, intf)
err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard {
@@ -91,19 +107,23 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = r.RemoveVPNRoute(testCase.prefix, intf)
err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
prefixGateway, _, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
require.NoError(t, err)
if testCase.prefix.Addr().Is6() {
internetGateway, _, err = GetNextHop(netip.MustParseAddr("::/0"))
}
require.NoError(t, err)
if testCase.shouldBeRemoved {
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
} else {
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
}
}
})
@@ -111,14 +131,11 @@ func TestAddRemoveRoutes(t *testing.T) {
}
func TestGetNextHop(t *testing.T) {
if runtime.GOOS == "freebsd" {
t.Skip("skipping on freebsd")
}
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if !nexthop.IP.IsValid() {
if !gateway.IsValid() {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
@@ -140,27 +157,39 @@ func TestGetNextHop(t *testing.T) {
}
}
localIP, err := GetNextHop(testingPrefix.Addr())
localIP, _, err := GetNextHop(testingPrefix.Addr())
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if !localIP.IP.IsValid() {
if !localIP.IsValid() {
t.Fatal("should return a gateway for local network")
}
if localIP.IP.String() == nexthop.IP.String() {
t.Fatal("local IP should not match with gateway IP")
if localIP.String() == gateway.String() {
t.Fatal("local ip should not match with gateway IP")
}
if localIP.IP.String() != testingIP {
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
if localIP.String() != testingIP {
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
}
}
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultNexthop: ", defaultNexthop)
defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultGateway: ", defaultGateway)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
var defaultGateway6 *netip.Addr
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
//goland:noinspection GoBoolExpressions
if iface.SupportsIPv6() && firewall.SupportsIPv6() && hasV6DefaultRoute && err == nil {
gw6, _, err := GetNextHop(netip.MustParseAddr("::"))
gw6 = gw6.WithZone("")
defaultGateway6 = &gw6
t.Log("defaultGateway6: ", defaultGateway6)
if err != nil {
t.Fatal("shouldn't return error when fetching the IPv6 gateway: ", err)
}
}
testCases := []struct {
name string
prefix netip.Prefix
@@ -174,7 +203,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
shouldAddRoute: false,
},
{
@@ -195,6 +224,43 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: false,
},
{
name: "Should Add And Remove random Route (IPv6)",
prefix: netip.MustParsePrefix("2001:db8::abcd/128"),
shouldAddRoute: true,
},
{
name: "Should Add Route if bigger network exists (IPv6)",
prefix: netip.MustParsePrefix("2001:db8:b14d:abcd:1234::/96"),
preExistingPrefix: netip.MustParsePrefix("2001:db8:b14d:abcd::/64"),
shouldAddRoute: true,
},
{
name: "Should Add Route if smaller network exists (IPv6)",
prefix: netip.MustParsePrefix("2001:db8:b14d::/48"),
preExistingPrefix: netip.MustParsePrefix("2001:db8:b14d:abcd::/64"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if same network exists (IPv6)",
prefix: netip.MustParsePrefix("2001:db8:b14d:abcd::/64"),
preExistingPrefix: netip.MustParsePrefix("2001:db8:b14d:abcd::/64"),
shouldAddRoute: false,
},
}
if defaultGateway6 != nil {
testCases = append(testCases, []struct {
name string
prefix netip.Prefix
preExistingPrefix netip.Prefix
shouldAddRoute bool
}{
{
name: "Should Not Add Route if overlaps with default gateway (IPv6)",
prefix: netip.MustParsePrefix(defaultGateway6.String() + "/127"),
shouldAddRoute: false,
},
}...)
}
for n, testCase := range testCases {
@@ -208,12 +274,19 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
v6Addr := ""
if testCase.prefix.Addr().Is6() && defaultGateway6 == nil {
t.Skip("Platform does not support IPv6, skipping IPv6 test...")
} else if testCase.prefix.Addr().Is6() {
v6Addr = "2001:db8::4242:4711/128"
}
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", v6Addr, 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
@@ -224,16 +297,14 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
err := addVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = r.AddVPNRoute(testCase.prefix, intf)
err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
@@ -243,7 +314,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.True(t, ok, "route should exist")
// remove route again if added
err = r.RemoveVPNRoute(testCase.prefix, intf)
err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
@@ -261,6 +332,11 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
}
func TestIsSubRange(t *testing.T) {
// Note: This test may fail for IPv6 in some environments, where there actually exists another route that the
// determined prefix is a sub-range of.
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
shouldIncludeV6Routes := iface.SupportsIPv6() && firewall.SupportsIPv6() && hasV6DefaultRoute && err == nil
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
@@ -270,7 +346,7 @@ func TestIsSubRange(t *testing.T) {
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
if !p.Addr().IsLoopback() && (p.Addr().Is4() && p.Bits() < 32) || (p.Addr().Is6() && shouldIncludeV6Routes && p.Bits() < 128) {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
@@ -298,31 +374,49 @@ func TestIsSubRange(t *testing.T) {
}
}
func EnvironmentHasIPv6DefaultRoute() (bool, error) {
//goland:noinspection GoBoolExpressions
if runtime.GOOS != "linux" {
// TODO when implementing IPv6 for other operating systems, this should be replaced with code that determines
// whether a default route for IPv6 exists (routing.Router panics on non-linux).
return false, nil
}
router, err := routing.New()
if err != nil {
return false, err
}
routeIface, _, _, err := router.Route(netip.MustParsePrefix("::/0").Addr().AsSlice())
if err != nil {
return false, err
}
return routeIface != nil, nil
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
shouldIncludeV6Routes := iface.SupportsIPv6() && firewall.SupportsIPv6() && hasV6DefaultRoute && err == nil
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
switch {
case p.Addr().Is6():
if p.Addr().Is6() && !shouldIncludeV6Routes {
continue
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
continue
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
continue
// FreeBSD loopback 127/8 is not added to the routing table
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
continue
default:
addressPrefixes = append(addressPrefixes, p.Masked())
}
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
continue
}
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
continue
}
addressPrefixes = append(addressPrefixes, p.Masked())
}
for _, prefix := range addressPrefixes {
@@ -336,7 +430,7 @@ func TestExistsInRouteTable(t *testing.T) {
}
}
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, ipAddress6CIDR string, listenPort int) *iface.WGIface {
t.Helper()
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
@@ -345,7 +439,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
newNet, err := stdnet.NewNet()
require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, ipAddress6CIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create()
@@ -358,52 +452,74 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
return wgInterface
}
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper()
err := r.AddVPNRoute(prefix, intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = r.RemoveVPNRoute(prefix, intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
func setupTestEnv(t *testing.T) {
t.Helper()
setupDummyInterfacesAndRoutes(t)
wgInterface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
v6Addr := ""
hasV6DefaultRoute, err := EnvironmentHasIPv6DefaultRoute()
//goland:noinspection GoBoolExpressions
if !iface.SupportsIPv6() || !firewall.SupportsIPv6() || !hasV6DefaultRoute || err != nil {
t.Skip("Platform does not support IPv6, skipping IPv6 test...")
} else {
v6Addr = "2001:db8::4242:4711/128"
}
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", v6Addr, 51820)
t.Cleanup(func() {
assert.NoError(t, wgInterface.Close())
assert.NoError(t, wgIface.Close())
})
r := NewSysOps(wgInterface)
_, _, err := r.SetupRouting(nil)
_, _, err = setupRouting(nil, wgIface)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting())
assert.NoError(t, cleanupRouting())
})
index, err := net.InterfaceByName(wgInterface.Name())
index, err := net.InterfaceByName(wgIface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
// default route exists in main table and vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("0.0.0.0/0"), intf)
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.0.0.0/8"), intf)
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
@@ -412,133 +528,17 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
return
}
prefixNexthop, err := GetNextHop(prefix.Addr())
prefixGateway, _, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
nexthop := wgIface.Address().IP.String()
if prefix.Addr().Is6() {
nexthop = wgIface.Address6().IP.String()
}
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
assert.NotEqual(t, nexthop, prefixGateway.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
}
}
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
assert.Equal(t, nexthop, prefixGateway.String(), "route should point to wireguard interface IP")
}
}

View File

@@ -1,11 +1,10 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package systemops
package routemanager
import (
"fmt"
"net"
"runtime"
"strings"
"testing"
"time"
@@ -86,10 +85,6 @@ var testCases = []testCase{
func TestRouting(t *testing.T) {
for _, tc := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", tc.name, " on freebsd")
}
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)

View File

@@ -1,6 +1,6 @@
//go:build windows
package systemops
package routemanager
import (
"fmt"
@@ -17,7 +17,8 @@ import (
"github.com/yusufpapurcu/wmi"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
type MSFT_NetRoute struct {
@@ -56,42 +57,14 @@ var prefixList []netip.Prefix
var lastUpdate time.Time
var mux = sync.Mutex{}
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses)
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
func (r *SysOps) CleanupRouting() error {
return r.cleanupRefCounter()
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
zone, err := strconv.Atoi(nexthop.IP.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
nexthop.Intf = &net.Interface{Index: zone}
}
return addRouteCmd(prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"delete", prefix.String()}
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
}
func getRoutesFromTable() ([]netip.Prefix, error) {
@@ -120,7 +93,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
func GetRoutes() ([]Route, error) {
var entries []MSFT_NetRoute
query := `SELECT DestinationPrefix, Nexthop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
@@ -145,10 +118,6 @@ func GetRoutes() ([]Route, error) {
Index: int(entry.InterfaceIndex),
Name: entry.InterfaceAlias,
}
if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) {
nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex)))
}
}
routes = append(routes, Route{
@@ -188,12 +157,11 @@ func GetNeighbors() ([]Neighbor, error) {
return neighbors, nil
}
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
args := []string{"add", prefix.String()}
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
@@ -202,8 +170,8 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args = append(args, addr)
}
if nexthop.Intf != nil {
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
if intf != nil {
args = append(args, "if", strconv.Itoa(intf.Index))
}
routeCmd := uspfilter.GetSystem32Command("route")
@@ -217,6 +185,37 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
return nil
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
if nexthop.Zone() != "" && intf == nil {
zone, err := strconv.Atoi(nexthop.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
intf = &net.Interface{Index: zone}
nexthop.WithZone("")
}
return addRouteCmd(prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
args = append(args, nexthop.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}

View File

@@ -1,4 +1,4 @@
package systemops
package routemanager
import (
"context"
@@ -29,7 +29,7 @@ type FindNetRouteOutput struct {
InterfaceIndex int `json:"InterfaceIndex"`
InterfaceAlias string `json:"InterfaceAlias"`
AddressFamily int `json:"AddressFamily"`
NextHop string `json:"Nexthop"`
NextHop string `json:"NextHop"`
DestinationPrefix string `json:"DestinationPrefix"`
}
@@ -166,7 +166,7 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
host, _, err := net.SplitHostPort(destination)
require.NoError(t, err)
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, Nexthop, DestinationPrefix | ConvertTo-Json`, host)
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
out, err := exec.Command("powershell", "-Command", script).Output()
require.NoError(t, err, "Failed to execute Find-NetRoute")
@@ -207,7 +207,7 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
}
func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)

View File

@@ -1,29 +0,0 @@
package util
import (
"fmt"
"net"
"net/netip"
)
// GetPrefixFromIP returns a netip.Prefix from a net.IP address.
func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return prefix, nil
}

View File

@@ -1,16 +0,0 @@
package vars
import (
"errors"
"net/netip"
)
const MinRangeBits = 7
var (
ErrRouteNotFound = errors.New("route not found")
ErrRouteNotAllowed = errors.New("route not allowed")
Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
)

View File

@@ -3,11 +3,11 @@ package routeselector
import (
"fmt"
"slices"
"strings"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors"
route "github.com/netbirdio/netbird/route"
)
@@ -30,10 +30,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
rs.selectedRoutes = map[route.NetID]struct{}{}
}
var err *multierror.Error
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
@@ -41,7 +41,11 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
}
rs.selectAll = false
return errors.FormatErrorOrNil(err)
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// SelectAllRoutes sets the selector to select all routes.
@@ -61,17 +65,21 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
}
}
var err *multierror.Error
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
delete(rs.selectedRoutes, route)
}
return errors.FormatErrorOrNil(err)
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
@@ -103,3 +111,18 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
return filtered
}
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}

View File

@@ -261,15 +261,15 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
require.NoError(t, err)
routes := route.HAMap{
"route1|10.0.0.0/8": {},
"route2|192.168.0.0/16": {},
"route3|172.16.0.0/12": {},
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
"route3-172.16.0.0/12": {},
}
filtered := rs.FilterSelected(routes)
assert.Equal(t, route.HAMap{
"route1|10.0.0.0/8": {},
"route2|192.168.0.0/16": {},
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
}, filtered)
}

View File

@@ -8,13 +8,9 @@ import (
log "github.com/sirupsen/logrus"
)
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
func NewFactory(ctx context.Context, wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
if userspace {
return f
}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen()
if err != nil {

Some files were not shown because too many files have changed in this diff Show More