mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 07:26:14 -04:00
Compare commits
2 Commits
test/engin
...
debug-0.29
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f30995707a | ||
|
|
5c8bfb7cea |
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
time go test -timeout 1m -failfast ./client/iface/...
|
||||
time go test -timeout 1m -failfast ./iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./signal/...
|
||||
|
||||
4
.github/workflows/golang-test-linux.yml
vendored
4
.github/workflows/golang-test-linux.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
@@ -80,7 +80,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Generate Iface Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
|
||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
|
||||
|
||||
- name: Generate Shared Sock Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif
|
||||
ignore_words_list: erro,clienta,hastable,
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
||||
1
.github/workflows/install-script-test.yml
vendored
1
.github/workflows/install-script-test.yml
vendored
@@ -13,7 +13,6 @@ concurrency:
|
||||
jobs:
|
||||
test-install-script:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
|
||||
76
.github/workflows/release.yml
vendored
76
.github/workflows/release.yml
vendored
@@ -3,14 +3,15 @@ name: Release
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
- 'v*'
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.14"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
GORELEASER_VER: "v1.14.1"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
|
||||
@@ -20,7 +21,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
@@ -33,16 +34,19 @@ jobs:
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
@@ -51,19 +55,24 @@ jobs:
|
||||
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-releaser-
|
||||
- name: Install modules
|
||||
-
|
||||
name: Install modules
|
||||
run: go mod tidy
|
||||
- name: check git status
|
||||
-
|
||||
name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
-
|
||||
name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
- name: Set up Docker Buildx
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Login to Docker hub
|
||||
-
|
||||
name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
username: netbirdio
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
@@ -76,31 +85,35 @@ jobs:
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
args: release --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: upload non tags for debug purposes
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
- name: upload linux packages
|
||||
-
|
||||
name: upload linux packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 3
|
||||
- name: upload windows packages
|
||||
-
|
||||
name: upload windows packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 3
|
||||
- name: upload macos packages
|
||||
-
|
||||
name: upload macos packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: macos-packages
|
||||
@@ -132,7 +145,7 @@ jobs:
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
@@ -156,7 +169,7 @@ jobs:
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
@@ -174,16 +187,19 @@ jobs:
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
@@ -192,19 +208,23 @@ jobs:
|
||||
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-ui-go-releaser-darwin-
|
||||
- name: Install modules
|
||||
-
|
||||
name: Install modules
|
||||
run: go mod tidy
|
||||
- name: check git status
|
||||
-
|
||||
name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
-
|
||||
name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
@@ -213,7 +233,7 @@ jobs:
|
||||
|
||||
trigger_signer:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [release, release_ui, release_ui_darwin]
|
||||
needs: [release,release_ui,release_ui_darwin]
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger binaries sign pipelines
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird
|
||||
builds:
|
||||
- id: netbird
|
||||
@@ -24,7 +22,7 @@ builds:
|
||||
goarch: 386
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
@@ -44,19 +42,19 @@ builds:
|
||||
- softfloat
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
- id: netbird-mgmt
|
||||
dir: management
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-mgmt
|
||||
goos:
|
||||
- linux
|
||||
@@ -66,7 +64,7 @@ builds:
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-signal
|
||||
dir: signal
|
||||
@@ -80,7 +78,7 @@ builds:
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-relay
|
||||
dir: relay
|
||||
@@ -94,7 +92,7 @@ builds:
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
archives:
|
||||
- builds:
|
||||
@@ -102,6 +100,7 @@ archives:
|
||||
- netbird-static
|
||||
|
||||
nfpms:
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
@@ -417,9 +416,10 @@ docker_manifests:
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
|
||||
brews:
|
||||
- ids:
|
||||
-
|
||||
ids:
|
||||
- default
|
||||
repository:
|
||||
tap:
|
||||
owner: netbirdio
|
||||
name: homebrew-tap
|
||||
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
|
||||
@@ -436,7 +436,7 @@ brews:
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird-deb
|
||||
- netbird-deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
builds:
|
||||
- id: netbird-ui
|
||||
@@ -13,7 +11,7 @@ builds:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-ui-windows
|
||||
dir: client/ui
|
||||
@@ -28,7 +26,7 @@ builds:
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -H windowsgui
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
archives:
|
||||
- id: linux-arch
|
||||
@@ -41,6 +39,7 @@ archives:
|
||||
- netbird-ui-windows
|
||||
|
||||
nfpms:
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
homepage: https://netbird.io/
|
||||
@@ -78,7 +77,7 @@ nfpms:
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird-ui-deb
|
||||
- netbird-ui-deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
builds:
|
||||
- id: netbird-ui-darwin
|
||||
@@ -19,7 +17,7 @@ builds:
|
||||
- softfloat
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
@@ -30,4 +28,4 @@ archives:
|
||||
checksum:
|
||||
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
|
||||
changelog:
|
||||
disable: true
|
||||
skip: true
|
||||
@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
|
||||
|
||||
**Goreleaser**
|
||||
```shell
|
||||
goreleaser build --snapshot --clean
|
||||
goreleaser --snapshot --rm-dist
|
||||
```
|
||||
**golangci-lint**
|
||||
```shell
|
||||
|
||||
@@ -49,8 +49,6 @@
|
||||
|
||||

|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
|
||||
### Key features
|
||||
|
||||
@@ -64,7 +62,6 @@
|
||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -26,7 +26,7 @@ type ConnectionListener interface {
|
||||
|
||||
// TunAdapter export internal TunAdapter for mobile
|
||||
type TunAdapter interface {
|
||||
device.TunAdapter
|
||||
iface.TunAdapter
|
||||
}
|
||||
|
||||
// IFaceDiscover export internal IFaceDiscover for mobile
|
||||
@@ -51,7 +51,7 @@ func init() {
|
||||
// Client struct manage the life circle of background service
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
tunAdapter device.TunAdapter
|
||||
tunAdapter iface.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func TestInitCommands(t *testing.T) {
|
||||
|
||||
@@ -805,9 +805,6 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||
}
|
||||
|
||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -33,12 +34,18 @@ func startTestingServices(t *testing.T) string {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testDir := t.TempDir()
|
||||
config.Datadir = testDir
|
||||
err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, signalLis := startSignal(t)
|
||||
signalAddr := signalLis.Addr().String()
|
||||
config.Signal.URI = signalAddr
|
||||
|
||||
_, mgmLis := startManagement(t, config, "../testdata/store.sqlite")
|
||||
_, mgmLis := startManagement(t, config)
|
||||
mgmAddr := mgmLis.Addr().String()
|
||||
return mgmAddr
|
||||
}
|
||||
@@ -50,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
srv, err := sig.NewServer(context.Background(), otel.Meter(""))
|
||||
srv, err := sig.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
sigProto.RegisterSignalExchangeServer(s, srv)
|
||||
@@ -63,7 +70,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
return s, lis
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
@@ -71,7 +78,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir())
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 1 {
|
||||
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||
if len(es) == 0 {
|
||||
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
SetFilter(iface.PacketFilter) error
|
||||
}
|
||||
|
||||
@@ -19,22 +19,24 @@ const (
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
|
||||
postRoutingMark = "0x000007e4"
|
||||
)
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routeingFwChainName string
|
||||
|
||||
entries map[string][][]string
|
||||
ipsetStore *ipsetStore
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routingFwChainName: routingFwChainName,
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routeingFwChainName: routeingFwChainName,
|
||||
|
||||
entries: make(map[string][][]string),
|
||||
ipsetStore: newIpsetStore(),
|
||||
@@ -59,7 +61,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
func (m *aclManager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -125,7 +127,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
|
||||
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -137,16 +139,28 @@ func (m *aclManager) AddPeerFiltering(
|
||||
chain: chain,
|
||||
}
|
||||
|
||||
return []firewall.Rule{rule}, nil
|
||||
if !shouldAddToPrerouting(protocol, dPort, direction) {
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
|
||||
if err != nil {
|
||||
return []firewall.Rule{rule}, err
|
||||
}
|
||||
return []firewall.Rule{rule, rulePrerouting}, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if r.chain == "PREROUTING" {
|
||||
goto DELETERULE
|
||||
}
|
||||
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
@@ -171,7 +185,14 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
}
|
||||
|
||||
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
|
||||
DELETERULE:
|
||||
var table string
|
||||
if r.chain == "PREROUTING" {
|
||||
table = "mangle"
|
||||
} else {
|
||||
table = "filter"
|
||||
}
|
||||
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
|
||||
if err != nil {
|
||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
||||
}
|
||||
@@ -182,6 +203,44 @@ func (m *aclManager) Reset() error {
|
||||
return m.cleanChains()
|
||||
}
|
||||
|
||||
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
|
||||
var src []string
|
||||
if ipsetName != "" {
|
||||
src = []string{"-m", "set", "--set", ipsetName, "src"}
|
||||
} else {
|
||||
src = []string{"-s", ip.String()}
|
||||
}
|
||||
specs := []string{
|
||||
"-d", m.wgIface.Address().IP.String(),
|
||||
"-p", protocol,
|
||||
"--dport", port,
|
||||
"-j", "MARK", "--set-mark", postRoutingMark,
|
||||
}
|
||||
|
||||
specs = append(src, specs...)
|
||||
|
||||
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: "PREROUTING",
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||
@@ -232,6 +291,25 @@ func (m *aclManager) cleanChains() error {
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
@@ -260,9 +338,17 @@ func (m *aclManager) createDefaultChains() error {
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
for _, rule := range rules {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
if chainName == "FORWARD" {
|
||||
// position 2 because we add it after router's, jump rule
|
||||
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -270,29 +356,40 @@ func (m *aclManager) createDefaultChains() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
|
||||
// We want to make sure our traffic is not dropped by existing rules.
|
||||
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
|
||||
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
established := getConntrackEstablished()
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
|
||||
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
|
||||
|
||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
|
||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
|
||||
m.appendToEntries("PREROUTING",
|
||||
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
@@ -359,3 +456,18 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
return ipsetName
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
||||
if proto == "all" {
|
||||
return false
|
||||
}
|
||||
|
||||
if direction != firewall.RuleDirectionIN {
|
||||
return false
|
||||
}
|
||||
|
||||
if dPort == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -4,14 +4,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// Manager of iptables firewall
|
||||
@@ -22,7 +21,7 @@ type Manager struct {
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
router *routerManager
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -44,12 +43,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouter(context, iptablesClient, wgIface)
|
||||
m.router, err = newRouterManager(context, iptablesClient)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize route related chains: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
||||
return nil, err
|
||||
@@ -58,10 +57,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -74,62 +73,33 @@ func (m *Manager) AddPeerFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources [] netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
return m.aclMgr.DeleteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -155,7 +125,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
_, err := m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
@@ -168,7 +138,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||
}
|
||||
_, err = m.AddPeerFiltering(
|
||||
_, err = m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
@@ -183,7 +153,3 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -11,24 +11,9 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
@@ -55,8 +40,23 @@ func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), ifaceMock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
@@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
port := &fw.Port{
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset()
|
||||
@@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(
|
||||
rule1, err = manager.AddFiltering(
|
||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||
)
|
||||
@@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
port := &fw.Port{
|
||||
Values: []int{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
@@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
@@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
@@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -5,478 +5,368 @@ package iptables
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4Nat = "netbird-rt-nat"
|
||||
Ipv4Forwarding = "netbird-rt-forwarding"
|
||||
ipv4Nat = "netbird-rt-nat"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
chainFORWARD = "FORWARD"
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWD = "NETBIRD-RT-FWD"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
matchSet = "--match-set"
|
||||
)
|
||||
|
||||
type routeFilteringRuleParams struct {
|
||||
Sources []netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Proto firewall.Protocol
|
||||
SPort *firewall.Port
|
||||
DPort *firewall.Port
|
||||
Direction firewall.RuleDirection
|
||||
Action firewall.Action
|
||||
SetName string
|
||||
type routerManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
iptablesClient *iptables.IPTables
|
||||
rules map[string][]string
|
||||
}
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
iptablesClient *iptables.IPTables
|
||||
rules map[string][]string
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
r := &router{
|
||||
m := &routerManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
func(name string, _ struct{}) error {
|
||||
return r.deleteIpSet(name)
|
||||
},
|
||||
)
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
err := m.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("cleanup routing rules: %s", err)
|
||||
log.Errorf("failed to cleanup routing rules: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
err = r.createContainers()
|
||||
err = m.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("create containers for route: %s", err)
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return r, err
|
||||
return m, err
|
||||
}
|
||||
|
||||
func (r *router) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var setName string
|
||||
if len(sources) > 1 {
|
||||
setName = firewall.GenerateSetName(sources)
|
||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: sources,
|
||||
Destination: destination,
|
||||
Proto: proto,
|
||||
SPort: sPort,
|
||||
DPort: dPort,
|
||||
Action: action,
|
||||
SetName: setName,
|
||||
}
|
||||
|
||||
rule := genRouteFilteringRuleSpec(params)
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return nil, fmt.Errorf("add route rule: %v", err)
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.GetRuleID()
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
setName := r.findSetNameInRule(rule)
|
||||
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("delete route rule: %v", err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("failed to remove ipset: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule []string) string {
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
return rule[i+3]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
|
||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return struct{}{}, nil
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string) error {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
||||
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the current legacy management mode
|
||||
func (r *router) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
// insertRoutingRule inserts an iptables rule
|
||||
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *router) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
||||
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
delete(i.rules, ruleKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) Reset() error {
|
||||
var merr *multierror.Error
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
r.rules = make(map[string][]string)
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
func (i *routerManager) RouteingFwChainName() string {
|
||||
return chainRTFWD
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
err := r.cleanJumpRules()
|
||||
func (i *routerManager) Reset() error {
|
||||
err := i.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules = make(map[string][]string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) cleanUpDefaultForwardRules() error {
|
||||
err := i.cleanJumpRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
table := tableFilter
|
||||
if chain == chainRTNAT {
|
||||
table = tableNat
|
||||
}
|
||||
|
||||
ok, err := r.iptablesClient.ChainExists(table, chain)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s, error: %v", chain, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
if err := r.createAndSetupChain(chain); err != nil {
|
||||
return fmt.Errorf("create chain %s: %v", chain, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
||||
return fmt.Errorf("insert established rule: %v", err)
|
||||
}
|
||||
|
||||
return r.addJumpRules()
|
||||
}
|
||||
|
||||
func (r *router) createAndSetupChain(chain string) error {
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
if err := r.iptablesClient.NewChain(table, chain); err != nil {
|
||||
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) getTableForChain(chain string) string {
|
||||
if chain == chainRTNAT {
|
||||
return tableNat
|
||||
}
|
||||
return tableFilter
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert established rule: %v", err)
|
||||
}
|
||||
|
||||
ruleKey := "established-" + chain
|
||||
r.rules[ruleKey] = establishedRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addJumpRules() error {
|
||||
rule := []string{"-j", chainRTNAT}
|
||||
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
}
|
||||
r.rules[ipv4Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
rule, found := r.rules[ipv4Nat]
|
||||
if found {
|
||||
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
|
||||
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nat rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||
intdir := "-i"
|
||||
if inverse {
|
||||
intdir = "-o"
|
||||
}
|
||||
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
var rule []string
|
||||
|
||||
if params.SetName != "" {
|
||||
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
||||
} else if len(params.Sources) > 0 {
|
||||
source := params.Sources[0]
|
||||
rule = append(rule, "-s", source.String())
|
||||
}
|
||||
|
||||
rule = append(rule, "-d", params.Destination.String())
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||
}
|
||||
|
||||
rule = append(rule, "-j", actionToStr(params.Action))
|
||||
|
||||
return rule
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
func (i *routerManager) createContainers() error {
|
||||
if i.rules[Ipv4Forwarding] != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
||||
errMSGFormat := "failed creating chain %s,error: %v"
|
||||
err := i.createChain(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
|
||||
}
|
||||
|
||||
if len(port.Values) > 1 {
|
||||
portList := make([]string, len(port.Values))
|
||||
for i, p := range port.Values {
|
||||
portList[i] = strconv.Itoa(p)
|
||||
}
|
||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||
err = i.createChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
|
||||
}
|
||||
|
||||
return []string{flag, strconv.Itoa(port.Values[0])}
|
||||
err = i.addJumpRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while creating jump rules: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addJumpRules create jump rules to send packets to NetBird chains
|
||||
func (i *routerManager) addJumpRules() error {
|
||||
rule := []string{"-j", chainRTFWD}
|
||||
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[Ipv4Forwarding] = rule
|
||||
|
||||
rule = []string{"-j", chainRTNAT}
|
||||
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
||||
func (i *routerManager) cleanJumpRules() error {
|
||||
var err error
|
||||
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
|
||||
rule, found := i.rules[Ipv4Forwarding]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4Nat]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) createChain(table, newChain string) error {
|
||||
chains, err := i.iptablesClient.ListChains(table)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
|
||||
}
|
||||
|
||||
shouldCreateChain := true
|
||||
for _, chain := range chains {
|
||||
if chain == newChain {
|
||||
shouldCreateChain = false
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCreateChain {
|
||||
err = i.iptablesClient.NewChain(table, newChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||
}
|
||||
|
||||
// Add the loopback return rule to the NAT chain
|
||||
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
|
||||
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNATRule appends an iptables rule pair to the nat chain
|
||||
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
// inserting after loopback ignore rule
|
||||
err := i.iptablesClient.Insert(table, chain, 2, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification
|
||||
func genRuleSpec(jump, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump}
|
||||
}
|
||||
|
||||
func getIptablesRuleType(table string) string {
|
||||
ruleType := "forwarding"
|
||||
if table == tableNat {
|
||||
ruleType = "nat"
|
||||
}
|
||||
return ruleType
|
||||
}
|
||||
|
||||
@@ -4,13 +4,11 @@ package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -30,7 +28,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
|
||||
defer func() {
|
||||
@@ -39,22 +37,26 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.100.0/24",
|
||||
Masquerade: true,
|
||||
}
|
||||
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
@@ -63,7 +65,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
}
|
||||
|
||||
func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
@@ -74,7 +76,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
defer func() {
|
||||
@@ -84,13 +86,35 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
foundRule, found := manager.rules[forwardRuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "income forwarding rule should exist")
|
||||
|
||||
foundRule, found = manager.rules[inForwardRuleKey]
|
||||
require.True(t, found, "income forwarding rule should exist in the manager map")
|
||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "nat rule should be created")
|
||||
@@ -103,8 +127,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||
}
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
@@ -122,7 +146,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
@@ -132,7 +156,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
@@ -140,14 +164,26 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
@@ -155,14 +191,28 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "forwarding rule should not exist")
|
||||
|
||||
_, found := manager.rules[forwardRuleKey]
|
||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "income forwarding rule should not exist")
|
||||
|
||||
_, found = manager.rules[inForwardRuleKey]
|
||||
require.False(t, found, "income forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "nat rule should not exist")
|
||||
|
||||
_, found := manager.rules[natRuleKey]
|
||||
_, found = manager.rules[natRuleKey]
|
||||
require.False(t, found, "nat rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
@@ -171,175 +221,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
_, found = manager.rules[inNatRuleKey]
|
||||
require.False(t, found, "income nat rule should exist in the manager map")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
if !isIptablesSupported() {
|
||||
t.Skip("iptables not supported on this system")
|
||||
}
|
||||
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
|
||||
defer func() {
|
||||
err := r.Reset()
|
||||
require.NoError(t, err, "Failed to reset router")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
direction firewall.RuleDirection
|
||||
action firewall.Action
|
||||
expectSet bool
|
||||
}{
|
||||
{
|
||||
name: "Basic TCP rule with single source",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with multiple sources",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: true,
|
||||
},
|
||||
{
|
||||
name: "All protocols rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolICMP,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with multiple source ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with single IP and port range",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with source and destination ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "Drop all incoming traffic",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
// Log the internal rule
|
||||
t.Logf("Internal rule: %v", rule)
|
||||
|
||||
// Check if the rule exists in iptables
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
// Verify rule content
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: tt.sources,
|
||||
Destination: tt.destination,
|
||||
Proto: tt.proto,
|
||||
SPort: tt.sPort,
|
||||
DPort: tt.dPort,
|
||||
Action: tt.action,
|
||||
SetName: "",
|
||||
}
|
||||
|
||||
expectedRule := genRouteFilteringRuleSpec(params)
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
params.SetName = setName
|
||||
expectedRule = genRouteFilteringRuleSpec(params)
|
||||
|
||||
// Check if the set was created
|
||||
_, exists := r.ipsetCounter.Get(setName)
|
||||
assert.True(t, exists, "IPSet not created")
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,15 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
NatFormat = "netbird-nat-%s"
|
||||
ForwardingFormat = "netbird-fwd-%s"
|
||||
InNatFormat = "netbird-nat-in-%s"
|
||||
InForwardingFormat = "netbird-fwd-in-%s"
|
||||
)
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
@@ -55,11 +49,11 @@ type Manager interface {
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddPeerFiltering(
|
||||
AddFiltering(
|
||||
ip net.IP,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
@@ -70,25 +64,17 @@ type Manager interface {
|
||||
comment string,
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
DeletePeerRule(rule Rule) error
|
||||
// DeleteRule from the firewall by rule definition
|
||||
DeleteRule(rule Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
||||
// InsertRoutingRules inserts a routing firewall rule
|
||||
InsertRoutingRules(pair RouterPair) error
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
|
||||
// AddNatRule inserts a routing NAT rule
|
||||
AddNatRule(pair RouterPair) error
|
||||
|
||||
// RemoveNatRule removes a routing NAT rule
|
||||
RemoveNatRule(pair RouterPair) error
|
||||
|
||||
// SetLegacyManagement sets the legacy management mode
|
||||
SetLegacyManagement(legacy bool) error
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
RemoveRoutingRules(pair RouterPair) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
@@ -97,89 +83,6 @@ type Manager interface {
|
||||
Flush() error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||
}
|
||||
|
||||
// LegacyManager defines the interface for legacy management operations
|
||||
type LegacyManager interface {
|
||||
RemoveAllLegacyRouteRules() error
|
||||
GetLegacyManagement() bool
|
||||
SetLegacyManagement(bool)
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||
oldLegacy := router.GetLegacyManagement()
|
||||
|
||||
if oldLegacy != isLegacy {
|
||||
router.SetLegacyManagement(isLegacy)
|
||||
log.Debugf("Set legacy management to %v", isLegacy)
|
||||
}
|
||||
|
||||
// client reconnected to a newer mgmt, we need to clean up the legacy rules
|
||||
if !isLegacy && oldLegacy {
|
||||
if err := router.RemoveAllLegacyRouteRules(); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Legacy routing rules removed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
sortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
sourcesStr.WriteString(src.String())
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
return fmt.Sprintf("nb-%s", shortHash)
|
||||
}
|
||||
|
||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
if len(prefixes) == 0 {
|
||||
return prefixes
|
||||
}
|
||||
|
||||
merged := []netip.Prefix{prefixes[0]}
|
||||
for _, prefix := range prefixes[1:] {
|
||||
last := merged[len(merged)-1]
|
||||
if last.Contains(prefix.Addr()) {
|
||||
// If the current prefix is contained within the last merged prefix, skip it
|
||||
continue
|
||||
}
|
||||
if prefix.Contains(last.Addr()) {
|
||||
// If the current prefix contains the last merged prefix, replace it
|
||||
merged[len(merged)-1] = prefix
|
||||
} else {
|
||||
// Otherwise, add the current prefix to the merged list
|
||||
merged = append(merged, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// sortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func sortPrefixes(prefixes []netip.Prefix) {
|
||||
sort.Slice(prefixes, func(i, j int) bool {
|
||||
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
||||
if addrCmp != 0 {
|
||||
return addrCmp < 0
|
||||
}
|
||||
|
||||
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
|
||||
return prefixes[i].Bits() > prefixes[j].Bits()
|
||||
})
|
||||
func GenKey(format string, input string) string {
|
||||
return fmt.Sprintf(format, input)
|
||||
}
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
package manager_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func TestGenerateSetName(t *testing.T) {
|
||||
t.Run("Different orders result in same hash", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Result format is correct", func(t *testing.T) {
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
|
||||
result := manager.GenerateSetName(prefixes)
|
||||
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
||||
if err != nil {
|
||||
t.Fatalf("Error matching regex: %v", err)
|
||||
}
|
||||
if !matched {
|
||||
t.Errorf("Result format is incorrect: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMergeIPRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Empty input",
|
||||
input: []netip.Prefix{},
|
||||
expected: []netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Single range",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Two non-overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Partially overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "IPv6 ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||
netip.MustParsePrefix("2001:db8:2::/48"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := manager.MergeIPRanges(tt.input)
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,26 +1,18 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouterPair struct {
|
||||
ID route.ID
|
||||
Source netip.Prefix
|
||||
Destination netip.Prefix
|
||||
ID string
|
||||
Source string
|
||||
Destination string
|
||||
Masquerade bool
|
||||
Inverse bool
|
||||
}
|
||||
|
||||
func GetInversePair(pair RouterPair) RouterPair {
|
||||
func GetInPair(pair RouterPair) RouterPair {
|
||||
return RouterPair{
|
||||
ID: pair.ID,
|
||||
// invert Source/Destination
|
||||
Source: pair.Destination,
|
||||
Destination: pair.Source,
|
||||
Masquerade: pair.Masquerade,
|
||||
Inverse: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,10 +33,9 @@ const (
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
|
||||
const flushError = "flush: %w"
|
||||
|
||||
var (
|
||||
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
|
||||
)
|
||||
|
||||
type AclManager struct {
|
||||
@@ -49,6 +48,7 @@ type AclManager struct {
|
||||
chainInputRules *nftables.Chain
|
||||
chainOutputRules *nftables.Chain
|
||||
chainFwFilter *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
@@ -64,7 +64,7 @@ type iFaceMapper interface {
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for both type of operations
|
||||
// and is permanent. Using same connection for booth type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
@@ -90,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *AclManager) AddPeerFiltering(
|
||||
func (m *AclManager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -120,11 +120,20 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
newRules = append(newRules, ioRule)
|
||||
if !shouldAddToPrerouting(proto, dPort, direction) {
|
||||
return newRules, nil
|
||||
}
|
||||
|
||||
preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
|
||||
if err != nil {
|
||||
return newRules, err
|
||||
}
|
||||
newRules = append(newRules, preroutingRule)
|
||||
return newRules, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *AclManager) DeleteRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
@@ -190,7 +199,8 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||
// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for
|
||||
// input and output chains
|
||||
func (m *AclManager) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Payload{
|
||||
@@ -204,13 +214,13 @@ func (m *AclManager) createDefaultAllowRules() error {
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
Mask: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
Data: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
@@ -236,13 +246,13 @@ func (m *AclManager) createDefaultAllowRules() error {
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
Mask: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
Data: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
@@ -256,8 +266,10 @@ func (m *AclManager) createDefaultAllowRules() error {
|
||||
Exprs: expOut,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
err := m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create default allow rules: %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -278,11 +290,15 @@ func (m *AclManager) Flush() error {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
|
||||
ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
r.nftRule,
|
||||
@@ -292,7 +308,18 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
}, nil
|
||||
}
|
||||
|
||||
var expressions []expr.Any
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
@@ -302,15 +329,21 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
protoData, err := protoToInt(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||
var protoData []byte
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
protoData = []byte{unix.IPPROTO_TCP}
|
||||
case firewall.ProtocolUDP:
|
||||
protoData = []byte{unix.IPPROTO_UDP}
|
||||
case firewall.ProtocolICMP:
|
||||
protoData = []byte{unix.IPPROTO_ICMP}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
expressions = append(expressions, &expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: []byte{protoData},
|
||||
Data: protoData,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -399,9 +432,10 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
} else {
|
||||
chain = m.chainOutputRules
|
||||
}
|
||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||
nftRule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: userData,
|
||||
})
|
||||
@@ -419,13 +453,139 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
|
||||
var protoData []byte
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
protoData = []byte{unix.IPPROTO_TCP}
|
||||
case firewall.ProtocolUDP:
|
||||
protoData = []byte{unix.IPPROTO_UDP}
|
||||
case firewall.ProtocolICMP:
|
||||
protoData = []byte{unix.IPPROTO_ICMP}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
r.nftRule,
|
||||
r.nftSet,
|
||||
r.ruleID,
|
||||
ip,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var ipExpression expr.Any
|
||||
// add individual IP for match if no ipset defined
|
||||
rawIP := ip.To4()
|
||||
if ipset == nil {
|
||||
ipExpression = &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
}
|
||||
} else {
|
||||
ipExpression = &expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipset.Name,
|
||||
SetID: ipset.ID,
|
||||
}
|
||||
}
|
||||
|
||||
expressions := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
ipExpression,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: m.wgIface.Address().IP.To4(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(9),
|
||||
Len: uint32(1),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: protoData,
|
||||
},
|
||||
}
|
||||
|
||||
if port != nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*port),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: postroutingMark,
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
)
|
||||
|
||||
nftRule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: []byte(ruleId),
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
nftRule: nftRule,
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
ip: ip,
|
||||
}
|
||||
|
||||
m.rules[ruleId] = rule
|
||||
if ipset != nil {
|
||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := m.createChain(chainNameInputRules)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
return err
|
||||
}
|
||||
m.chainInputRules = chain
|
||||
|
||||
@@ -441,6 +601,9 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
//netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
|
||||
m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
|
||||
m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
|
||||
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||
err = m.rConn.Flush()
|
||||
@@ -452,6 +615,7 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// netbird-acl-output-filter
|
||||
// type filter hook output priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
|
||||
m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
|
||||
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
|
||||
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
|
||||
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
|
||||
@@ -463,15 +627,24 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
m.addJumpRulesToRtForward() // to netbird-rt-fwd
|
||||
m.addJumpRulesToRtForward() // to
|
||||
m.addMarkAccept()
|
||||
m.addJumpRuleToInputChain() // to netbird-acl-input-rules
|
||||
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// netbird-acl-output-filter
|
||||
// type filter hook output priority filter; policy accept;
|
||||
m.chainPrerouting = m.createPreroutingMangle()
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -494,6 +667,59 @@ func (m *AclManager) addJumpRulesToRtForward() {
|
||||
Chain: m.chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
|
||||
expressions = []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.routeingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addMarkAccept() {
|
||||
// oifname "wt0" meta mark 0x000007e4 accept
|
||||
// iifname "wt0" meta mark 0x000007e4 accept
|
||||
ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
|
||||
for _, iface := range ifaces {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: iface, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: postroutingMark,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||
@@ -503,9 +729,6 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||
}
|
||||
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
insertReturnTrafficRule(m.rConn, m.workTable, chain)
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
@@ -523,6 +746,74 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha
|
||||
return m.rConn.AddChain(chain)
|
||||
}
|
||||
|
||||
func (m *AclManager) createPreroutingMangle() *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: "netbird-acl-prerouting-filter",
|
||||
Table: m.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: m.wgIface.Address().IP.To4(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: postroutingMark,
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
}
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
return chain
|
||||
}
|
||||
|
||||
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
@@ -541,9 +832,101 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRuleToInputChain() {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.chainInputRules.Name,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
var srcOp, dstOp expr.CmpOp
|
||||
if netIfName == expr.MetaKeyIIFNAME {
|
||||
srcOp = expr.CmpOpEq
|
||||
dstOp = expr.CmpOpNeq
|
||||
} else {
|
||||
srcOp = expr.CmpOpNeq
|
||||
dstOp = expr.CmpOpEq
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: netIfName, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: srcOp,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: dstOp,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
dstOp := expr.CmpOpNeq
|
||||
var srcOp, dstOp expr.CmpOp
|
||||
if iifname == expr.MetaKeyIIFNAME {
|
||||
srcOp = expr.CmpOpNeq
|
||||
dstOp = expr.CmpOpEq
|
||||
} else {
|
||||
srcOp = expr.CmpOpEq
|
||||
dstOp = expr.CmpOpNeq
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: iifname, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -551,6 +934,24 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: srcOp,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
@@ -581,6 +982,7 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -588,12 +990,47 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: to,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
@@ -695,7 +1132,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(
|
||||
func generateRuleId(
|
||||
ip net.IP,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
@@ -718,6 +1155,33 @@ func generatePeerRuleId(
|
||||
}
|
||||
return "set:" + ipset.Name + rulesetID
|
||||
}
|
||||
func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
|
||||
// case of icmp port is empty
|
||||
var p string
|
||||
if port != nil {
|
||||
p = port.String()
|
||||
}
|
||||
if ipset != nil {
|
||||
return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
|
||||
} else {
|
||||
return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
||||
if proto == "all" {
|
||||
return false
|
||||
}
|
||||
|
||||
if direction != firewall.RuleDirectionIN {
|
||||
return false
|
||||
}
|
||||
|
||||
if dPort == nil && proto != firewall.ProtocolICMP {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func encodePort(port firewall.Port) []byte {
|
||||
bs := make([]byte, 2)
|
||||
@@ -727,19 +1191,6 @@ func encodePort(port firewall.Port) []byte {
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
copy(b, []byte(n+"\x00"))
|
||||
return b
|
||||
}
|
||||
|
||||
func protoToInt(protocol firewall.Protocol) (uint8, error) {
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
return unix.IPPROTO_TCP, nil
|
||||
case firewall.ProtocolUDP:
|
||||
return unix.IPPROTO_UDP, nil
|
||||
case firewall.ProtocolICMP:
|
||||
return unix.IPPROTO_ICMP, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
@@ -5,11 +5,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -17,11 +15,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
|
||||
tableNameNetbird = "netbird"
|
||||
|
||||
tableNameFilter = "filter"
|
||||
chainNameInput = "INPUT"
|
||||
// tableName is the name of the table that is used for filtering by the Netbird client
|
||||
tableName = "netbird"
|
||||
)
|
||||
|
||||
// Manager of iptables firewall
|
||||
@@ -46,12 +41,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.router, err = newRouter(context, workTable, wgIface)
|
||||
m.router, err = newRouter(context, workTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -59,11 +54,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -81,52 +76,33 @@ func (m *Manager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclManager.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
return m.aclManager.DeleteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
return m.router.AddRoutingRules(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveNatRule(pair)
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
@@ -150,7 +126,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
@@ -181,27 +157,6 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
oldLegacy := m.router.legacyManagement
|
||||
|
||||
if oldLegacy != isLegacy {
|
||||
m.router.legacyManagement = isLegacy
|
||||
log.Debugf("Set legacy management to %v", isLegacy)
|
||||
}
|
||||
|
||||
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
|
||||
if !isLegacy && oldLegacy {
|
||||
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Legacy routing rules removed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
@@ -230,16 +185,14 @@ func (m *Manager) Reset() error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset forward rules: %v", err)
|
||||
}
|
||||
m.router.ResetForwardRules()
|
||||
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
@@ -265,12 +218,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
@@ -286,7 +239,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
@@ -296,7 +251,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||
ifName := ifname(m.wgIface.Name())
|
||||
for _, rule := range existedRules {
|
||||
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
|
||||
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
||||
if len(rule.Exprs) < 4 {
|
||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||
continue
|
||||
@@ -310,33 +265,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
conn.InsertRule(rule)
|
||||
}
|
||||
|
||||
@@ -9,30 +9,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
@@ -56,9 +40,23 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), ifaceMock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
@@ -72,7 +70,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(
|
||||
rule, err := manager.AddFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
@@ -90,34 +88,17 @@ func TestNftablesManager(t *testing.T) {
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 2, "expected 2 rules")
|
||||
|
||||
expectedExprs1 := []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
||||
require.Len(t, rules, 1, "expected 1 rules")
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs2 := []expr.Any{
|
||||
expectedExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
@@ -153,10 +134,10 @@ func TestNftablesManager(t *testing.T) {
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
||||
|
||||
for _, r := range rule {
|
||||
err = manager.DeletePeerRule(r)
|
||||
err = manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
@@ -165,8 +146,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
@@ -207,9 +187,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
|
||||
431
client/firewall/nftables/route_linux.go
Normal file
431
client/firewall/nftables/route_linux.go
Normal file
@@ -0,0 +1,431 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
chainNameRouteingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-nat"
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
|
||||
loopbackInterface = "lo\x00"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
var (
|
||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
||||
|
||||
exprCounterAccept = []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
)
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
isDefaultFwdRulesEnabled bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (r *router) RouteingFwChainName() string {
|
||||
return chainNameRouteingFw
|
||||
}
|
||||
|
||||
// ResetForwardRules cleans existing nftables default forward rules from the system
|
||||
func (r *router) ResetForwardRules() {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset forward rules: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
|
||||
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRouteingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Add RETURN rule for loopback interface
|
||||
loRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte(loopbackInterface),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictReturn},
|
||||
},
|
||||
}
|
||||
r.conn.InsertRule(loRule)
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
||||
log.Debugf("add default accept forward rule")
|
||||
r.acceptForwardRule(pair.Source)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
var expression []expr.Any
|
||||
if isNat {
|
||||
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
|
||||
} else {
|
||||
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
|
||||
}
|
||||
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
_, exists := r.rules[ruleKey]
|
||||
if exists {
|
||||
err := r.removeRoutingRule(format, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainName],
|
||||
Exprs: expression,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
src := generateCIDRMatcherExpressions(true, sourceNetwork)
|
||||
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
|
||||
|
||||
var exprs []expr.Any
|
||||
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||
}
|
||||
|
||||
r.conn.AddRule(rule)
|
||||
|
||||
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
|
||||
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
|
||||
|
||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule = &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||
}
|
||||
r.conn.AddRule(rule)
|
||||
r.isDefaultFwdRulesEnabled = true
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.NatFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(r.rules) == 0 {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
||||
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
rule, found := r.rules[ruleKey]
|
||||
if found {
|
||||
ruleType := "forwarding"
|
||||
if rule.Chain.Type == nftables.ChainTypeNAT {
|
||||
ruleType = "nat"
|
||||
}
|
||||
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rules []*nftables.Rule
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
continue
|
||||
}
|
||||
if chain.Name != "FORWARD" {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err = r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
||||
ip, network, _ := net.ParseCIDR(cidr)
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
var offSet uint32
|
||||
if source {
|
||||
offSet = 12 // src offset
|
||||
} else {
|
||||
offSet = 16 // dst offset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
// fetch src add
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offSet,
|
||||
Len: 4,
|
||||
},
|
||||
// net mask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Mask: network.Mask,
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,798 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-nat"
|
||||
chainNameForward = "FORWARD"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
)
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
r.deleteIpSet,
|
||||
)
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables default forward rules from the system
|
||||
func (r *router) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
return r.cleanUpDefaultForwardRules()
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.acceptForwardRules()
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||
func (r *router) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
chain := r.chains[chainNameRoutingFw]
|
||||
var exprs []expr.Any
|
||||
|
||||
switch {
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
// If it's 0.0.0.0/0, we don't need to add any source matching
|
||||
case len(sources) == 1:
|
||||
// If there's only one source, we can use it directly
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
|
||||
default:
|
||||
// If there are multiple sources, create or get an ipset
|
||||
var err error
|
||||
exprs, err = r.getIpSetExprs(sources, exprs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ipset expressions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle destination
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := protoToInt(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
})
|
||||
|
||||
exprs = append(exprs, applyPort(sPort, true)...)
|
||||
exprs = append(exprs, applyPort(dPort, false)...)
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Counter{})
|
||||
|
||||
var verdict expr.VerdictKind
|
||||
if action == firewall.ActionAccept {
|
||||
verdict = expr.VerdictAccept
|
||||
} else {
|
||||
verdict = expr.VerdictDrop
|
||||
}
|
||||
exprs = append(exprs, &expr.Verdict{Kind: verdict})
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = r.conn.AddRule(rule)
|
||||
|
||||
return ruleKey, r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||
setName := firewall.GenerateSetName(sources)
|
||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
)
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleKey := rule.GetRuleID()
|
||||
nftRule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
setName := r.findSetNameInRule(nftRule)
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
return fmt.Errorf("delete: %w", err)
|
||||
}
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("decrement ipset reference: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
sources = firewall.MergeIPRanges(sources)
|
||||
|
||||
set := &nftables.Set{
|
||||
Name: setName,
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
}
|
||||
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range sources {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
|
||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||
firstIP := prefix.Addr()
|
||||
lastIP := calculateLastIP(prefix).Next()
|
||||
|
||||
elements = append(elements,
|
||||
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
|
||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
}
|
||||
|
||||
if err := r.conn.AddSet(set, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||
hostMask := ^uint32(0) >> prefix.Masked().Bits()
|
||||
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
|
||||
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
|
||||
// Utility function to convert netip.Addr to uint32.
|
||||
func uint32FromNetipAddr(addr netip.Addr) uint32 {
|
||||
b := addr.As4()
|
||||
return binary.BigEndian.Uint32(b[:])
|
||||
}
|
||||
|
||||
// Utility function to convert uint32 to a netip-compatible byte slice.
|
||||
func uint32ToBytes(ip uint32) [4]byte {
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], ip)
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
||||
r.conn.DelSet(set)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
return lookup.SetName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule %s: %w", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
log.Debugf("removed route rule %s", ruleKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNatRule appends a nftables rule pair to the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
dir := expr.MetaKeyIIFNAME
|
||||
if pair.Inverse {
|
||||
dir = expr.MetaKeyOIFNAME
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: dir,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
exprs = append(exprs,
|
||||
&expr.Counter{}, &expr.Masq{},
|
||||
)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Exprs: expression,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the route manager's legacy management mode
|
||||
func (r *router) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *router) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
|
||||
// that our traffic is not dropped by existing rules there.
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
func (r *router) acceptForwardRules() {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
return
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
// Rule for incoming interface (iif) with counter
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||
}
|
||||
r.conn.InsertRule(iifRule)
|
||||
|
||||
// Rule for outgoing interface (oif) with counter
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 2,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
|
||||
r.conn.InsertRule(oifRule)
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: nat rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||
var offset uint32
|
||||
if source {
|
||||
offset = 12 // src offset
|
||||
} else {
|
||||
offset = 16 // dst offset
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
// 0.0.0.0/0 doesn't need extra expressions
|
||||
if ones == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(ones, 32)
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: 4,
|
||||
},
|
||||
// netmask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Mask: mask,
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: prefix.Masked().Addr().AsSlice(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
|
||||
offset := uint32(2) // Default offset for destination port
|
||||
if isSource {
|
||||
offset = 0 // Offset for source port
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: offset,
|
||||
Len: 2,
|
||||
})
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
// Handle port range
|
||||
exprs = append(exprs,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpLte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// Handle single port or multiple ports
|
||||
for i, p := range port.Values {
|
||||
if i > 0 {
|
||||
// Add a bitwise OR operation between port checks
|
||||
exprs = append(exprs, &expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0x00, 0x00, 0xff, 0xff},
|
||||
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
})
|
||||
}
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return exprs
|
||||
}
|
||||
@@ -4,15 +4,11 @@ package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -28,50 +24,56 @@ const (
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer func(manager *router) {
|
||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||
}(manager)
|
||||
defer manager.ResetForwardRules()
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "pair should be inserted")
|
||||
err = manager.AddRoutingRules(testCase.InputPair)
|
||||
defer func() {
|
||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
}()
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
defer func(manager *router, pair firewall.RouterPair) {
|
||||
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
|
||||
}(manager, testCase.InputPair)
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
testingExpression = append(testingExpression,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
@@ -86,20 +88,27 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
testingExpression = append(testingExpression,
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
)
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
|
||||
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
found = 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
@@ -113,37 +122,45 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer func(manager *router) {
|
||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||
}(manager)
|
||||
defer manager.ResetForwardRules()
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(forwardRuleKey),
|
||||
})
|
||||
|
||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
@@ -152,11 +169,20 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
UserData: []byte(natRuleKey),
|
||||
})
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(inForwardRuleKey),
|
||||
})
|
||||
|
||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
@@ -168,10 +194,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
manager.ResetForwardRules()
|
||||
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
for _, chain := range manager.chains {
|
||||
@@ -179,7 +204,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||
}
|
||||
}
|
||||
@@ -188,468 +215,6 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
|
||||
defer func(r *router) {
|
||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||
}(r)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
direction firewall.RuleDirection
|
||||
action firewall.Action
|
||||
expectSet bool
|
||||
}{
|
||||
{
|
||||
name: "Basic TCP rule with single source",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with multiple sources",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: true,
|
||||
},
|
||||
{
|
||||
name: "All protocols rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolICMP,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with multiple source ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with single IP and port range",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with source and destination ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "Drop all incoming traffic",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
t.Log("Internal rule expressions:")
|
||||
for i, expr := range rule.Exprs {
|
||||
t.Logf(" [%d] %T: %+v", i, expr, expr)
|
||||
}
|
||||
|
||||
// Verify internal rule content
|
||||
verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||
|
||||
// Check if the rule exists in nftables and verify its content
|
||||
rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
|
||||
require.NoError(t, err, "Failed to get rules from nftables")
|
||||
|
||||
var nftRule *nftables.Rule
|
||||
for _, rule := range rules {
|
||||
if string(rule.UserData) == ruleKey.GetRuleID() {
|
||||
nftRule = rule
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, nftRule, "Rule not found in nftables")
|
||||
t.Log("Actual nftables rule expressions:")
|
||||
for i, expr := range nftRule.Exprs {
|
||||
t.Logf(" [%d] %T: %+v", i, expr, expr)
|
||||
}
|
||||
|
||||
// Verify actual nftables rule content
|
||||
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesCreateIpSet(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Single IP",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPs",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
netip.MustParsePrefix("10.0.0.1/32"),
|
||||
netip.MustParsePrefix("172.16.0.1/32"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Single Subnet",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
|
||||
},
|
||||
{
|
||||
name: "Multiple Subnets with Various Prefix Lengths",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("203.0.113.0/26"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Mix of Single IPs and Subnets in Different Positions",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
netip.MustParsePrefix("10.0.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.1/32"),
|
||||
netip.MustParsePrefix("203.0.113.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping IPs/Subnets",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("10.0.0.0/16"),
|
||||
netip.MustParsePrefix("10.0.0.1/32"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add this helper function inside TestNftablesCreateIpSet
|
||||
printNftSets := func() {
|
||||
cmd := exec.Command("nft", "list", "sets")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Failed to run 'nft list sets': %v", err)
|
||||
} else {
|
||||
t.Logf("Current nft sets:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
set, err := r.createIpSet(setName, tt.sources)
|
||||
if err != nil {
|
||||
t.Logf("Failed to create IP set: %v", err)
|
||||
printNftSets()
|
||||
require.NoError(t, err, "Failed to create IP set")
|
||||
}
|
||||
require.NotNil(t, set, "Created set is nil")
|
||||
|
||||
// Verify set properties
|
||||
assert.Equal(t, setName, set.Name, "Set name mismatch")
|
||||
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
|
||||
assert.True(t, set.Interval, "Set interval property should be true")
|
||||
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
|
||||
|
||||
// Fetch the created set from nftables
|
||||
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
|
||||
require.NoError(t, err, "Failed to fetch created set")
|
||||
require.NotNil(t, fetchedSet, "Fetched set is nil")
|
||||
|
||||
// Verify set elements
|
||||
elements, err := r.conn.GetSetElements(fetchedSet)
|
||||
require.NoError(t, err, "Failed to get set elements")
|
||||
|
||||
// Count the number of unique prefixes (excluding interval end markers)
|
||||
uniquePrefixes := make(map[string]bool)
|
||||
for _, elem := range elements {
|
||||
if !elem.IntervalEnd {
|
||||
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
|
||||
uniquePrefixes[ip.String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check against expected merged prefixes
|
||||
expectedCount := len(tt.expected)
|
||||
if expectedCount == 0 {
|
||||
expectedCount = len(tt.sources)
|
||||
}
|
||||
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
|
||||
|
||||
// Verify each expected prefix is in the set
|
||||
for _, expected := range tt.expected {
|
||||
found := false
|
||||
for _, elem := range elements {
|
||||
if !elem.IntervalEnd {
|
||||
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
|
||||
if expected.Contains(ip) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected prefix %s not found in set", expected)
|
||||
}
|
||||
|
||||
r.conn.DelSet(set)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
t.Logf("Failed to delete set: %v", err)
|
||||
printNftSets()
|
||||
}
|
||||
require.NoError(t, err, "Failed to delete set")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||
t.Helper()
|
||||
|
||||
assert.NotNil(t, rule, "Rule should not be nil")
|
||||
|
||||
// Verify sources and destination
|
||||
if expectSet {
|
||||
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
|
||||
} else if len(sources) == 1 && sources[0].Bits() != 0 {
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
|
||||
} else {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
|
||||
}
|
||||
}
|
||||
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
|
||||
} else {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
|
||||
}
|
||||
|
||||
// Verify protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
|
||||
}
|
||||
|
||||
// Verify ports
|
||||
if sPort != nil {
|
||||
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
|
||||
}
|
||||
if dPort != nil {
|
||||
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
|
||||
}
|
||||
|
||||
// Verify action
|
||||
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
|
||||
}
|
||||
|
||||
func containsSetLookup(exprs []expr.Any) bool {
|
||||
for _, e := range exprs {
|
||||
if _, ok := e.(*expr.Lookup); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
|
||||
var offset uint32
|
||||
if isSource {
|
||||
offset = 12 // src offset
|
||||
} else {
|
||||
offset = 16 // dst offset
|
||||
}
|
||||
|
||||
var payloadFound, bitwiseFound, cmpFound bool
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Payload:
|
||||
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
|
||||
payloadFound = true
|
||||
}
|
||||
case *expr.Bitwise:
|
||||
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
|
||||
bitwiseFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
|
||||
cmpFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
|
||||
}
|
||||
|
||||
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
var offset uint32 = 2 // Default offset for destination port
|
||||
if isSource {
|
||||
offset = 0 // Offset for source port
|
||||
}
|
||||
|
||||
var payloadFound, portMatchFound bool
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Payload:
|
||||
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
||||
payloadFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if port.IsRange {
|
||||
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
|
||||
portMatchFound = true
|
||||
}
|
||||
} else {
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
||||
portValue := binary.BigEndian.Uint16(ex.Data)
|
||||
for _, p := range port.Values {
|
||||
if uint16(p) == portValue {
|
||||
portMatchFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if payloadFound && portMatchFound {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||
var metaFound, cmpFound bool
|
||||
expectedProto, _ := protoToInt(proto)
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Meta:
|
||||
if ex.Key == expr.MetaKeyL4PROTO {
|
||||
metaFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
|
||||
cmpFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return metaFound && cmpFound
|
||||
}
|
||||
|
||||
func containsAction(exprs []expr.Any, action firewall.Action) bool {
|
||||
for _, e := range exprs {
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
return verdict.Kind == expr.VerdictAccept
|
||||
case firewall.ActionDrop:
|
||||
return verdict.Kind == expr.VerdictDrop
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() int {
|
||||
nf := nftables.Conn{}
|
||||
@@ -685,12 +250,12 @@ func createWorkTable() (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
err = sConn.Flush()
|
||||
|
||||
return table, err
|
||||
@@ -708,7 +273,7 @@ func deleteWorkTable() {
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
//go:build !android
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
|
||||
var (
|
||||
InsertRuleTestCases = []struct {
|
||||
@@ -15,8 +13,8 @@ var (
|
||||
Name: "Insert Forwarding IPV4 Rule",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
@@ -24,8 +22,8 @@ var (
|
||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
@@ -40,8 +38,8 @@ var (
|
||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -3,7 +3,6 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -12,8 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
@@ -24,7 +22,7 @@ var (
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
SetFilter(iface.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
@@ -105,26 +103,26 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
return m.nativeFirewall.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
return m.nativeFirewall.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -190,22 +188,8 @@ func (m *Manager) AddPeerFiltering(
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -231,11 +215,6 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(_ bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
@@ -416,7 +395,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
return m.DeleteRule(&rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -424,7 +403,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
return m.DeleteRule(&rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,16 +11,15 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
SetFilterFunc func(iface.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||
if i.SetFilterFunc == nil {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
@@ -36,7 +35,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
|
||||
|
||||
func TestManagerCreate(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -50,10 +49,10 @@ func TestManagerCreate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
func TestManagerAddFiltering(t *testing.T) {
|
||||
isSetFilterCalled := false
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error {
|
||||
SetFilterFunc: func(iface.PacketFilter) error {
|
||||
isSetFilterCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -72,7 +71,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -91,7 +90,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
|
||||
func TestManagerDeleteRule(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -107,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -120,14 +119,14 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule {
|
||||
err = m.DeletePeerRule(r)
|
||||
err = m.DeleteRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
@@ -141,7 +140,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
err = m.DeletePeerRule(r)
|
||||
err = m.DeleteRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
@@ -237,7 +236,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
|
||||
func TestManagerReset(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -253,7 +252,7 @@ func TestManagerReset(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -272,7 +271,7 @@ func TestManagerReset(t *testing.T) {
|
||||
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -291,7 +290,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -340,7 +339,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
func TestRemovePacketHook(t *testing.T) {
|
||||
// creating mock iface
|
||||
iface := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
@@ -389,7 +388,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
@@ -407,9 +406,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package configurer
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrPeerNotFound = errors.New("peer not found")
|
||||
@@ -1,9 +0,0 @@
|
||||
package configurer
|
||||
|
||||
import "time"
|
||||
|
||||
type WGStats struct {
|
||||
LastHandshake time.Time
|
||||
TxBytes int64
|
||||
RxBytes int64
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create() (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
)
|
||||
|
||||
type WGConfigurer interface {
|
||||
ConfigureInterface(privateKey string, port int) error
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close()
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
package device
|
||||
|
||||
// CustomWindowsGUIDString is a custom GUID string for the interface
|
||||
var CustomWindowsGUIDString string
|
||||
@@ -1,16 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
type RuleID string
|
||||
|
||||
func (r RuleID) GetRuleID() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func GenerateRouteRuleKey(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto manager.Protocol,
|
||||
sPort *manager.Port,
|
||||
dPort *manager.Port,
|
||||
action manager.Action,
|
||||
) RuleID {
|
||||
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,7 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
@@ -25,18 +23,16 @@ type Manager interface {
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||
routeRules map[id.RuleID]struct{}
|
||||
mutex sync.Mutex
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
firewall: fm,
|
||||
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
||||
routeRules: make(map[id.RuleID]struct{}),
|
||||
firewall: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +46,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
total := 0
|
||||
for _, pairs := range d.peerRulesPairs {
|
||||
for _, pairs := range d.rulesPairs {
|
||||
total += len(pairs)
|
||||
}
|
||||
log.Infof(
|
||||
@@ -63,34 +59,21 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
return
|
||||
}
|
||||
|
||||
d.applyPeerACLs(networkMap)
|
||||
defer func() {
|
||||
if err := d.firewall.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||
log.Errorf("failed to set legacy management flag: %v", err)
|
||||
}
|
||||
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||
}
|
||||
|
||||
if err := d.firewall.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||
|
||||
enableSSH := networkMap.PeerConfig != nil &&
|
||||
enableSSH := (networkMap.PeerConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled)
|
||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
|
||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
|
||||
@@ -100,9 +83,9 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
if enableSSH {
|
||||
rules = append(rules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
||||
})
|
||||
}
|
||||
@@ -114,20 +97,20 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules = append(rules,
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
newRulePairs := make(map[string][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
|
||||
for _, r := range rules {
|
||||
@@ -147,97 +130,29 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
break
|
||||
}
|
||||
if len(rules) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
d.rulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
}
|
||||
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
for pairID, rules := range d.rulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete peer firewall rule: %v", err)
|
||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
delete(d.peerRulesPairs, pairID)
|
||||
delete(d.rulesPairs, pairID)
|
||||
}
|
||||
}
|
||||
d.peerRulesPairs = newRulePairs
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
||||
var newRouteRules = make(map[id.RuleID]struct{})
|
||||
for _, rule := range rules {
|
||||
id, err := d.applyRouteACL(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply route ACL: %w", err)
|
||||
}
|
||||
newRouteRules[id] = struct{}{}
|
||||
}
|
||||
|
||||
for id := range d.routeRules {
|
||||
if _, ok := newRouteRules[id]; !ok {
|
||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||
log.Errorf("failed to delete route firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
delete(d.routeRules, id)
|
||||
}
|
||||
}
|
||||
d.routeRules = newRouteRules
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
||||
if len(rule.SourceRanges) == 0 {
|
||||
return "", fmt.Errorf("source ranges is empty")
|
||||
}
|
||||
|
||||
var sources []netip.Prefix
|
||||
for _, sourceRange := range rule.SourceRanges {
|
||||
source, err := netip.ParsePrefix(sourceRange)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse source range: %w", err)
|
||||
}
|
||||
sources = append(sources, source)
|
||||
}
|
||||
|
||||
var destination netip.Prefix
|
||||
if rule.IsDynamic {
|
||||
destination = getDefault(sources[0])
|
||||
} else {
|
||||
var err error
|
||||
destination, err = netip.ParsePrefix(rule.Destination)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse destination: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid protocol: %w", err)
|
||||
}
|
||||
|
||||
action, err := convertFirewallAction(rule.Action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid action: %w", err)
|
||||
}
|
||||
|
||||
dPorts := convertPortInfo(rule.PortInfo)
|
||||
|
||||
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add route rule: %w", err)
|
||||
}
|
||||
|
||||
return id.RuleID(addedRule.GetRuleID()), nil
|
||||
d.rulesPairs = newRulePairs
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (id.RuleID, []firewall.Rule, error) {
|
||||
) (string, []firewall.Rule, error) {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
@@ -264,16 +179,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
}
|
||||
}
|
||||
|
||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
switch r.Direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
case mgmProto.FirewallRule_OUT:
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
@@ -295,7 +210,7 @@ func (d *DefaultManager) addInRules(
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
rule, err := d.firewall.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
@@ -306,7 +221,7 @@ func (d *DefaultManager) addInRules(
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
rule, err = d.firewall.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
@@ -324,7 +239,7 @@ func (d *DefaultManager) addOutRules(
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
rule, err := d.firewall.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
@@ -335,7 +250,7 @@ func (d *DefaultManager) addOutRules(
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
rule, err = d.firewall.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
@@ -344,21 +259,21 @@ func (d *DefaultManager) addOutRules(
|
||||
return append(rules, rule...), nil
|
||||
}
|
||||
|
||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getPeerRuleID(
|
||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getRuleID(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
comment string,
|
||||
) id.RuleID {
|
||||
) string {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
|
||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
|
||||
}
|
||||
|
||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
||||
@@ -368,7 +283,7 @@ func (d *DefaultManager) getPeerRuleID(
|
||||
// but other has port definitions or has drop policy.
|
||||
func (d *DefaultManager) squashAcceptRules(
|
||||
networkMap *mgmProto.NetworkMap,
|
||||
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
|
||||
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
|
||||
totalIPs := 0
|
||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||
for range p.AllowedIps {
|
||||
@@ -376,14 +291,14 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
}
|
||||
}
|
||||
|
||||
type protoMatch map[mgmProto.RuleProtocol]map[string]int
|
||||
type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
|
||||
|
||||
in := protoMatch{}
|
||||
out := protoMatch{}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
|
||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||
|
||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||
@@ -393,7 +308,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
//
|
||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
|
||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||
drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
|
||||
if drop {
|
||||
protocols[r.Protocol] = map[string]int{}
|
||||
return
|
||||
@@ -421,7 +336,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
// calculate squash for different directions
|
||||
if r.Direction == mgmProto.RuleDirection_IN {
|
||||
if r.Direction == mgmProto.FirewallRule_IN {
|
||||
addRuleToCalculationMap(i, r, in)
|
||||
} else {
|
||||
addRuleToCalculationMap(i, r, out)
|
||||
@@ -430,14 +345,14 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
|
||||
// order of squashing by protocol is important
|
||||
// only for their first element ALL, it must be done first
|
||||
protocolOrders := []mgmProto.RuleProtocol{
|
||||
mgmProto.RuleProtocol_ALL,
|
||||
mgmProto.RuleProtocol_ICMP,
|
||||
mgmProto.RuleProtocol_TCP,
|
||||
mgmProto.RuleProtocol_UDP,
|
||||
protocolOrders := []mgmProto.FirewallRuleProtocol{
|
||||
mgmProto.FirewallRule_ALL,
|
||||
mgmProto.FirewallRule_ICMP,
|
||||
mgmProto.FirewallRule_TCP,
|
||||
mgmProto.FirewallRule_UDP,
|
||||
}
|
||||
|
||||
squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
|
||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
||||
for _, protocol := range protocolOrders {
|
||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||
// don't squash if :
|
||||
@@ -450,12 +365,12 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: direction,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: protocol,
|
||||
})
|
||||
squashedProtocols[protocol] = struct{}{}
|
||||
|
||||
if protocol == mgmProto.RuleProtocol_ALL {
|
||||
if protocol == mgmProto.FirewallRule_ALL {
|
||||
// if we have ALL traffic type squashed rule
|
||||
// it allows all other type of traffic, so we can stop processing
|
||||
break
|
||||
@@ -463,11 +378,11 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
}
|
||||
}
|
||||
|
||||
squash(in, mgmProto.RuleDirection_IN)
|
||||
squash(out, mgmProto.RuleDirection_OUT)
|
||||
squash(in, mgmProto.FirewallRule_IN)
|
||||
squash(out, mgmProto.FirewallRule_OUT)
|
||||
|
||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
||||
return squashedRules, squashedProtocols
|
||||
}
|
||||
|
||||
@@ -497,26 +412,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
|
||||
log.Debugf("rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.RuleProtocol_TCP:
|
||||
case mgmProto.FirewallRule_TCP:
|
||||
return firewall.ProtocolTCP, nil
|
||||
case mgmProto.RuleProtocol_UDP:
|
||||
case mgmProto.FirewallRule_UDP:
|
||||
return firewall.ProtocolUDP, nil
|
||||
case mgmProto.RuleProtocol_ICMP:
|
||||
case mgmProto.FirewallRule_ICMP:
|
||||
return firewall.ProtocolICMP, nil
|
||||
case mgmProto.RuleProtocol_ALL:
|
||||
case mgmProto.FirewallRule_ALL:
|
||||
return firewall.ProtocolALL, nil
|
||||
default:
|
||||
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
@@ -527,41 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
|
||||
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
||||
}
|
||||
|
||||
func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) {
|
||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
|
||||
switch action {
|
||||
case mgmProto.RuleAction_ACCEPT:
|
||||
case mgmProto.FirewallRule_ACCEPT:
|
||||
return firewall.ActionAccept, nil
|
||||
case mgmProto.RuleAction_DROP:
|
||||
case mgmProto.FirewallRule_DROP:
|
||||
return firewall.ActionDrop, nil
|
||||
default:
|
||||
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
||||
}
|
||||
}
|
||||
|
||||
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
||||
if portInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if portInfo.GetPort() != 0 {
|
||||
return &firewall.Port{
|
||||
Values: []int{int(portInfo.GetPort())},
|
||||
}
|
||||
}
|
||||
|
||||
if portInfo.GetRange() != nil {
|
||||
return &firewall.Port{
|
||||
IsRange: true,
|
||||
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDefault(prefix netip.Prefix) netip.Prefix {
|
||||
if prefix.Addr().Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
@@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) {
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_DROP,
|
||||
Protocol: mgmProto.FirewallRule_UDP,
|
||||
Port: "53",
|
||||
},
|
||||
},
|
||||
@@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) {
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
||||
if len(acl.rulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
existedPairs := map[string]struct{}{}
|
||||
for id := range acl.peerRulesPairs {
|
||||
existedPairs[id.GetRuleID()] = struct{}{}
|
||||
for id := range acl.rulesPairs {
|
||||
existedPairs[id] = struct{}{}
|
||||
}
|
||||
|
||||
// remove first rule
|
||||
@@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) {
|
||||
networkMap.FirewallRules,
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_ICMP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_DROP,
|
||||
Protocol: mgmProto.FirewallRule_ICMP,
|
||||
},
|
||||
)
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
if len(acl.rulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied")
|
||||
return
|
||||
}
|
||||
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.peerRulesPairs {
|
||||
if _, ok := existedPairs[id.GetRuleID()]; ok {
|
||||
for id := range acl.rulesPairs {
|
||||
if _, ok := existedPairs[id]; ok {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
@@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) {
|
||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = true
|
||||
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
|
||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
||||
if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
|
||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
|
||||
return
|
||||
}
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = false
|
||||
acl.ApplyFiltering(networkMap)
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
if len(acl.rulesPairs) != 2 {
|
||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
|
||||
return
|
||||
}
|
||||
})
|
||||
@@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
case r.Direction != mgmProto.RuleDirection_IN:
|
||||
case r.Direction != mgmProto.FirewallRule_IN:
|
||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
||||
return
|
||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
@@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
case r.Direction != mgmProto.RuleDirection_OUT:
|
||||
case r.Direction != mgmProto.FirewallRule_OUT:
|
||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
||||
return
|
||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
@@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
if len(acl.peerRulesPairs) != 4 {
|
||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
if len(acl.rulesPairs) != 4 {
|
||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,7 @@ import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
iface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
iface "github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// MockIFaceMapper is a mock of IFaceMapper interface.
|
||||
@@ -78,7 +77,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
|
||||
}
|
||||
|
||||
// SetFilter mocks base method.
|
||||
func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error {
|
||||
func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetFilter", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
||||
@@ -16,9 +16,9 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -17,14 +17,13 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
@@ -71,7 +70,7 @@ func (c *ConnectClient) RunWithProbes(
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
func (c *ConnectClient) RunOnAndroid(
|
||||
tunAdapter device.TunAdapter,
|
||||
tunAdapter iface.TunAdapter,
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []string,
|
||||
@@ -206,7 +205,7 @@ func (c *ConnectClient) run(
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
||||
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||
}
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
@@ -269,6 +268,12 @@ func (c *ConnectClient) run(
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
if c.engine != nil && c.engine.ctx.Err() != nil {
|
||||
log.Info("Stopping Netbird Engine")
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
}
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
||||
|
||||
c.engineMutex.Unlock()
|
||||
@@ -288,15 +293,6 @@ func (c *ConnectClient) run(
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
c.engineMutex.Lock()
|
||||
if c.engine != nil && c.engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
c.engine = nil
|
||||
}
|
||||
c.engineMutex.Unlock()
|
||||
c.statusRecorder.ClientTeardown()
|
||||
|
||||
backOff.Reset()
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/iface/mocks"
|
||||
)
|
||||
|
||||
func TestResponseWriterLocalAddr(t *testing.T) {
|
||||
|
||||
@@ -15,18 +15,16 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||
)
|
||||
|
||||
type mocWGIface struct {
|
||||
filter device.PacketFilter
|
||||
filter iface.PacketFilter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Name() string {
|
||||
@@ -45,11 +43,11 @@ func (w *mocWGIface) ToInterface() *net.Interface {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetFilter() device.PacketFilter {
|
||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||
return w.filter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetDevice() *device.FilteredDevice {
|
||||
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
@@ -61,13 +59,13 @@ func (w *mocWGIface) IsUserspaceBind() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *mocWGIface) SetFilter(filter device.PacketFilter) error {
|
||||
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
||||
w.filter = filter
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
|
||||
return configurer.WGStats{}, nil
|
||||
func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) {
|
||||
return iface.WGStats{}, nil
|
||||
}
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
|
||||
@@ -5,9 +5,7 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
@@ -16,7 +14,7 @@ type WGIface interface {
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
GetStats(peerKey string) (iface.WGStats, error)
|
||||
}
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
GetStats(peerKey string) (iface.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
|
||||
@@ -23,12 +23,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
@@ -39,6 +36,8 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
@@ -141,7 +140,7 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgInterface IWGIface
|
||||
wgInterface iface.IWGIface
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
@@ -251,13 +250,6 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
log.Info("Network monitor: stopped")
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
err := e.removeAllPeers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||
@@ -300,7 +292,7 @@ func (e *Engine) Start() error {
|
||||
e.wgInterface = wgIface
|
||||
|
||||
userspace := e.wgInterface.IsUserspaceBind()
|
||||
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
@@ -326,7 +318,7 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface.(*iface.WGIface), e.statusRecorder, e.relayManager, initialRoutes)
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
@@ -627,7 +619,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
|
||||
IP: e.config.WgAddr,
|
||||
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
||||
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||
FQDN: conf.GetFqdn(),
|
||||
})
|
||||
|
||||
@@ -712,11 +704,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply ACLs in the beginning to avoid security leaks
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
@@ -783,6 +770,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
@@ -921,7 +912,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
wgConfig := peer.WgConfig{
|
||||
RemoteKey: pubKey,
|
||||
WgListenPort: e.config.WgPort,
|
||||
WgInterface: e.wgInterface.(*iface.WGIface),
|
||||
WgInterface: e.wgInterface,
|
||||
AllowedIps: allowedIPs,
|
||||
PreSharedKey: e.config.PreSharedKey,
|
||||
}
|
||||
@@ -1123,12 +1114,18 @@ func (e *Engine) close() {
|
||||
}
|
||||
}
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
}
|
||||
e.wgInterface = nil
|
||||
}
|
||||
|
||||
if !isNil(e.sshServer) {
|
||||
@@ -1167,15 +1164,15 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
var mArgs *device.MobileIFaceArguments
|
||||
var mArgs *iface.MobileIFaceArguments
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
mArgs = &device.MobileIFaceArguments{
|
||||
mArgs = &iface.MobileIFaceArguments{
|
||||
TunAdapter: e.mobileDep.TunAdapter,
|
||||
TunFd: int(e.mobileDep.FileDescriptor),
|
||||
}
|
||||
case "ios":
|
||||
mArgs = &device.MobileIFaceArguments{
|
||||
mArgs = &iface.MobileIFaceArguments{
|
||||
TunFd: int(e.mobileDep.FileDescriptor),
|
||||
}
|
||||
default:
|
||||
@@ -1396,7 +1393,7 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}
|
||||
|
||||
// Set a new timer to debounce rapid network changes
|
||||
debounceTimer = time.AfterFunc(2*time.Second, func() {
|
||||
debounceTimer = time.AfterFunc(1*time.Second, func() {
|
||||
// This function is called after the debounce period
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
@@ -1427,11 +1424,6 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSServer() {
|
||||
if e.dnsServer == nil {
|
||||
return
|
||||
}
|
||||
e.dnsServer.Stop()
|
||||
e.dnsServer = nil
|
||||
err := fmt.Errorf("DNS server stopped")
|
||||
nsGroupStates := e.statusRecorder.GetDNSStates()
|
||||
for i := range nsGroupStates {
|
||||
@@ -1439,6 +1431,10 @@ func (e *Engine) stopDNSServer() {
|
||||
nsGroupStates[i].Error = err
|
||||
}
|
||||
e.statusRecorder.UpdateDNSStates(nsGroupStates)
|
||||
if e.dnsServer != nil {
|
||||
e.dnsServer.Stop()
|
||||
e.dnsServer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -24,15 +25,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
mgmt "github.com/netbirdio/netbird/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@@ -242,13 +242,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil)
|
||||
|
||||
wgIface := &MockWGIface{
|
||||
wgIface := &iface.MockWGIface{
|
||||
RemovePeerFunc: func(peerKey string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface.(*iface.WGIface), engine.statusRecorder, relayMgr, nil)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
@@ -823,6 +823,20 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
@@ -832,7 +846,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite")
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -860,7 +874,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
iface.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
@@ -1042,7 +1056,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
srv, err := signalServer.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
@@ -1055,7 +1069,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &server.Config{
|
||||
@@ -1080,7 +1094,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// MobileDependency collect all dependencies for mobile platform
|
||||
type MobileDependency struct {
|
||||
// Android only
|
||||
TunAdapter device.TunAdapter
|
||||
TunAdapter iface.TunAdapter
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
NetworkChangeListener listener.NetworkChangeListener
|
||||
HostDNSAddresses []string
|
||||
|
||||
@@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
defer func() {
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||
log.Errorf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
<-ctx.Done()
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Debugf("Network monitor: closed routing socket: %v", err)
|
||||
log.Debugf("Network monitor: closed routing socket")
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
log.Errorf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -15,10 +15,9 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -32,14 +31,12 @@ const (
|
||||
connPriorityRelay ConnPriority = 1
|
||||
connPriorityICETurn ConnPriority = 1
|
||||
connPriorityICEP2P ConnPriority = 2
|
||||
|
||||
reconnectMaxElapsedTime = 30 * time.Minute
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
WgListenPort int
|
||||
RemoteKey string
|
||||
WgInterface *iface.WGIface
|
||||
WgInterface iface.IWGIface
|
||||
AllowedIps string
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
@@ -82,8 +79,9 @@ type Conn struct {
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
allowedIPsIP string
|
||||
handshaker *Handshaker
|
||||
@@ -104,14 +102,11 @@ type Conn struct {
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
endpointRelay *net.UDPAddr
|
||||
|
||||
// for reconnection operations
|
||||
iCEDisconnected chan bool
|
||||
relayDisconnected chan bool
|
||||
connMonitor *ConnMonitor
|
||||
reconnectCh <-chan struct{}
|
||||
}
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
@@ -127,31 +122,21 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
var conn = &Conn{
|
||||
log: connLog,
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
wgProxyFactory: wgProxyFactory,
|
||||
signaler: signaler,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
relayManager: relayManager,
|
||||
allowedIPsIP: allowedIPsIP.String(),
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
|
||||
log: connLog,
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
wgProxyFactory: wgProxyFactory,
|
||||
signaler: signaler,
|
||||
relayManager: relayManager,
|
||||
allowedIPsIP: allowedIPsIP.String(),
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
iCEDisconnected: make(chan bool, 1),
|
||||
relayDisconnected: make(chan bool, 1),
|
||||
}
|
||||
|
||||
conn.connMonitor, conn.reconnectCh = NewConnMonitor(
|
||||
signaler,
|
||||
iFaceDiscover,
|
||||
config,
|
||||
conn.relayDisconnected,
|
||||
conn.iCEDisconnected,
|
||||
)
|
||||
|
||||
rFns := WorkerRelayCallbacks{
|
||||
OnConnReady: conn.relayConnectionIsReady,
|
||||
OnDisconnected: conn.onWorkerRelayStateDisconnected,
|
||||
@@ -214,8 +199,6 @@ func (conn *Conn) startHandshakeAndReconnect() {
|
||||
conn.log.Errorf("failed to send initial offer: %v", err)
|
||||
}
|
||||
|
||||
go conn.connMonitor.Start(conn.ctx)
|
||||
|
||||
if conn.workerRelay.IsController() {
|
||||
conn.reconnectLoopWithRetry()
|
||||
} else {
|
||||
@@ -256,7 +239,8 @@ func (conn *Conn) Close() {
|
||||
conn.wgProxyICE = nil
|
||||
}
|
||||
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
|
||||
@@ -324,14 +308,12 @@ func (conn *Conn) reconnectLoopWithRetry() {
|
||||
// With it, we can decrease to send necessary offer
|
||||
select {
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
case <-time.After(3 * time.Second):
|
||||
}
|
||||
|
||||
ticker := conn.prepareExponentTicker()
|
||||
defer ticker.Stop()
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
for {
|
||||
select {
|
||||
case t := <-ticker.C:
|
||||
@@ -359,11 +341,20 @@ func (conn *Conn) reconnectLoopWithRetry() {
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to do handshake: %v", err)
|
||||
}
|
||||
|
||||
case <-conn.reconnectCh:
|
||||
case changed := <-conn.relayDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
conn.log.Debugf("Relay state changed, reset reconnect timer")
|
||||
ticker.Stop()
|
||||
ticker = conn.prepareExponentTicker()
|
||||
case changed := <-conn.iCEDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
conn.log.Debugf("ICE state changed, reset reconnect timer")
|
||||
ticker.Stop()
|
||||
ticker = conn.prepareExponentTicker()
|
||||
|
||||
case <-conn.ctx.Done():
|
||||
conn.log.Debugf("context is done, stop reconnect loop")
|
||||
return
|
||||
@@ -374,10 +365,10 @@ func (conn *Conn) reconnectLoopWithRetry() {
|
||||
func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
RandomizationFactor: 0.01,
|
||||
Multiplier: 2,
|
||||
MaxInterval: conn.config.Timeout,
|
||||
MaxElapsedTime: reconnectMaxElapsedTime,
|
||||
MaxElapsedTime: 0,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, conn.ctx)
|
||||
@@ -428,59 +419,54 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
|
||||
conn.log.Debugf("ICE connection is ready")
|
||||
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
|
||||
defer conn.updateIceState(iceConnInfo)
|
||||
|
||||
if conn.currentConnPriority > priority {
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.updateIceState(iceConnInfo)
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Infof("set ICE to active connection")
|
||||
|
||||
var (
|
||||
ep *net.UDPAddr
|
||||
wgProxy wgproxy.Proxy
|
||||
err error
|
||||
)
|
||||
if iceConnInfo.RelayedOnLocal {
|
||||
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
ep = wgProxy.EndpointAddr()
|
||||
conn.wgProxyICE = wgProxy
|
||||
} else {
|
||||
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolveUDPaddr")
|
||||
conn.handleConfigurationFailure(err, nil)
|
||||
return
|
||||
}
|
||||
ep = directEp
|
||||
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
|
||||
|
||||
conn.connIDICE = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Pause()
|
||||
}
|
||||
|
||||
if wgProxy != nil {
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
if err = conn.configureWGEndpoint(ep); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
||||
if err != nil {
|
||||
if wgProxy != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close turn connection: %v", err)
|
||||
}
|
||||
}
|
||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
if conn.wgProxyICE != nil {
|
||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyICE = wgProxy
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.updateIceState(iceConnInfo)
|
||||
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
}
|
||||
|
||||
@@ -495,18 +481,11 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
|
||||
conn.log.Tracef("ICE connection state changed to %s", newState)
|
||||
|
||||
if conn.wgProxyICE != nil {
|
||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// switch back to relay connection
|
||||
if conn.isReadyToUpgrade() {
|
||||
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
|
||||
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||
err := conn.configureWGEndpoint(conn.endpointRelay)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
@@ -516,7 +495,10 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||
conn.statusICE.Set(newState)
|
||||
|
||||
conn.notifyReconnectLoopICEDisconnected(changed)
|
||||
select {
|
||||
case conn.iCEDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -537,48 +519,61 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
if conn.ctx.Err() != nil {
|
||||
if err := rci.relayedConn.Close(); err != nil {
|
||||
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||
conn.log.Debugf("Relay connection is ready to use")
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
|
||||
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
||||
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
|
||||
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.endpointRelay = endpointUdpAddr
|
||||
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||
|
||||
if conn.iceP2PIsActive() {
|
||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
|
||||
if conn.currentConnPriority > connPriorityRelay {
|
||||
if conn.statusICE.Get() == StatusConnected {
|
||||
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
conn.connIDRelay = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
|
||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
||||
if err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
|
||||
conn.log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
}
|
||||
@@ -591,23 +586,25 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Debugf("relay connection is disconnected")
|
||||
|
||||
if conn.currentConnPriority == connPriorityRelay {
|
||||
conn.log.Debugf("clean up WireGuard config")
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
if conn.wgProxyRelay != nil {
|
||||
log.Debugf("relayed connection is closed, clean up WireGuard config")
|
||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.endpointRelay = nil
|
||||
_ = conn.wgProxyRelay.CloseConn()
|
||||
conn.wgProxyRelay = nil
|
||||
}
|
||||
|
||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
conn.notifyReconnectLoopRelayDisconnected(changed)
|
||||
|
||||
select {
|
||||
case conn.relayDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -615,7 +612,9 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
||||
Relayed: conn.isRelayed(),
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||
if err != nil {
|
||||
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -681,7 +680,7 @@ func (conn *Conn) setStatusToDisconnected() {
|
||||
// todo rethink status updates
|
||||
conn.log.Debugf("error while updating peer's state, err: %v", err)
|
||||
}
|
||||
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
|
||||
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil {
|
||||
conn.log.Debugf("failed to reset wireguard stats for peer: %s", err)
|
||||
}
|
||||
}
|
||||
@@ -751,16 +750,6 @@ func (conn *Conn) isConnected() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
||||
conn.connIDICE = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDICE, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) freeUpConnID() {
|
||||
if conn.connIDRelay != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
@@ -781,52 +770,21 @@ func (conn *Conn) freeUpConnID() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.log.Debugf("setup proxied WireGuard connection")
|
||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
|
||||
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
|
||||
if !iceConnInfo.RelayedOnLocal {
|
||||
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
|
||||
}
|
||||
conn.log.Debugf("setup ice turn connection")
|
||||
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
||||
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return wgProxy, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||
}
|
||||
|
||||
func (conn *Conn) iceP2PIsActive() bool {
|
||||
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) removeWgPeer() error {
|
||||
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
|
||||
select {
|
||||
case conn.relayDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
|
||||
select {
|
||||
case conn.iCEDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if wgProxy != nil {
|
||||
if ierr := wgProxy.CloseConn(); ierr != nil {
|
||||
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||
if errClose := wgProxy.CloseConn(); errClose != nil {
|
||||
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Work()
|
||||
}
|
||||
return ep, wgProxy, nil
|
||||
}
|
||||
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
const (
|
||||
signalerMonitorPeriod = 5 * time.Second
|
||||
candidatesMonitorPeriod = 5 * time.Minute
|
||||
candidateGatheringTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ConnMonitor struct {
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
config ConnConfig
|
||||
relayDisconnected chan bool
|
||||
iCEDisconnected chan bool
|
||||
reconnectCh chan struct{}
|
||||
currentCandidates []ice.Candidate
|
||||
candidatesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) {
|
||||
reconnectCh := make(chan struct{}, 1)
|
||||
cm := &ConnMonitor{
|
||||
signaler: signaler,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
config: config,
|
||||
relayDisconnected: relayDisconnected,
|
||||
iCEDisconnected: iCEDisconnected,
|
||||
reconnectCh: reconnectCh,
|
||||
}
|
||||
return cm, reconnectCh
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) Start(ctx context.Context) {
|
||||
signalerReady := make(chan struct{}, 1)
|
||||
go cm.monitorSignalerReady(ctx, signalerReady)
|
||||
|
||||
localCandidatesChanged := make(chan struct{}, 1)
|
||||
go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged)
|
||||
|
||||
for {
|
||||
select {
|
||||
case changed := <-cm.relayDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
log.Debugf("Relay state changed, triggering reconnect")
|
||||
cm.triggerReconnect()
|
||||
|
||||
case changed := <-cm.iCEDisconnected:
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
log.Debugf("ICE state changed, triggering reconnect")
|
||||
cm.triggerReconnect()
|
||||
|
||||
case <-signalerReady:
|
||||
log.Debugf("Signaler became ready, triggering reconnect")
|
||||
cm.triggerReconnect()
|
||||
|
||||
case <-localCandidatesChanged:
|
||||
log.Debugf("Local candidates changed, triggering reconnect")
|
||||
cm.triggerReconnect()
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) {
|
||||
if cm.signaler == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(signalerMonitorPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
lastReady := true
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
currentReady := cm.signaler.Ready()
|
||||
if !lastReady && currentReady {
|
||||
select {
|
||||
case signalerReady <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
lastReady = currentReady
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) {
|
||||
ufrag, pwd, err := generateICECredentials()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate ICE credentials: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(candidatesMonitorPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil {
|
||||
log.Warnf("Failed to handle candidate tick: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error {
|
||||
log.Debugf("Gathering ICE candidates")
|
||||
|
||||
transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ICE agent: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := agent.Close(); err != nil {
|
||||
log.Warnf("Failed to close ICE agent: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
gatherDone := make(chan struct{})
|
||||
err = agent.OnCandidate(func(c ice.Candidate) {
|
||||
log.Tracef("Got candidate: %v", c)
|
||||
if c == nil {
|
||||
close(gatherDone)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("set ICE candidate handler: %w", err)
|
||||
}
|
||||
|
||||
if err := agent.GatherCandidates(); err != nil {
|
||||
return fmt.Errorf("gather ICE candidates: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("wait for gathering: %w", ctx.Err())
|
||||
case <-gatherDone:
|
||||
}
|
||||
|
||||
candidates, err := agent.GetLocalCandidates()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get local candidates: %w", err)
|
||||
}
|
||||
log.Tracef("Got candidates: %v", candidates)
|
||||
|
||||
if changed := cm.updateCandidates(candidates); changed {
|
||||
select {
|
||||
case localCandidatesChanged <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
|
||||
cm.candidatesMu.Lock()
|
||||
defer cm.candidatesMu.Unlock()
|
||||
|
||||
if len(cm.currentCandidates) != len(newCandidates) {
|
||||
cm.currentCandidates = newCandidates
|
||||
return true
|
||||
}
|
||||
|
||||
for i, candidate := range cm.currentCandidates {
|
||||
if candidate.Address() != newCandidates[i].Address() {
|
||||
cm.currentCandidates = newCandidates
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (cm *ConnMonitor) triggerReconnect() {
|
||||
select {
|
||||
case cm.reconnectCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
)
|
||||
@@ -203,7 +203,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return State{}, configurer.ErrPeerNotFound
|
||||
return State{}, iface.ErrPeerNotFound
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
@@ -412,7 +412,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
}
|
||||
|
||||
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
|
||||
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error {
|
||||
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
|
||||
@@ -6,6 +6,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(ifaceBlacklist)
|
||||
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,6 @@ package peer
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
|
||||
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
|
||||
}
|
||||
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
"github.com/pion/stun/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -233,16 +233,41 @@ func (w *WorkerICE) Close() {
|
||||
}
|
||||
|
||||
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
|
||||
transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
|
||||
transportNet, err := w.newStdNet()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
w.sentExtraSrflx = false
|
||||
iceKeepAlive := iceKeepAlive()
|
||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||
|
||||
agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd)
|
||||
agentConfig := &ice.AgentConfig{
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||
Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI),
|
||||
CandidateTypes: relaySupport,
|
||||
InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList),
|
||||
UDPMux: w.config.ICEConfig.UDPMux,
|
||||
UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx,
|
||||
NAT1To1IPs: w.config.ICEConfig.NATExternalIPs,
|
||||
Net: transportNet,
|
||||
FailedTimeout: &failedTimeout,
|
||||
DisconnectedTimeout: &iceDisconnectedTimeout,
|
||||
KeepaliveInterval: &iceKeepAlive,
|
||||
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
|
||||
LocalUfrag: w.localUfrag,
|
||||
LocalPwd: w.localPwd,
|
||||
}
|
||||
|
||||
if w.config.ICEConfig.DisableIPv6Discovery {
|
||||
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
}
|
||||
|
||||
w.sentExtraSrflx = false
|
||||
agent, err := ice.NewAgent(agentConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create agent: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = agent.OnCandidate(w.onICECandidate)
|
||||
@@ -365,36 +390,6 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA
|
||||
}
|
||||
}
|
||||
|
||||
func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) {
|
||||
iceKeepAlive := iceKeepAlive()
|
||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||
|
||||
agentConfig := &ice.AgentConfig{
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||
Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI),
|
||||
CandidateTypes: candidateTypes,
|
||||
InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList),
|
||||
UDPMux: config.ICEConfig.UDPMux,
|
||||
UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx,
|
||||
NAT1To1IPs: config.ICEConfig.NATExternalIPs,
|
||||
Net: transportNet,
|
||||
FailedTimeout: &failedTimeout,
|
||||
DisconnectedTimeout: &iceDisconnectedTimeout,
|
||||
KeepaliveInterval: &iceKeepAlive,
|
||||
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
|
||||
LocalUfrag: ufrag,
|
||||
LocalPwd: pwd,
|
||||
}
|
||||
|
||||
if config.ICEConfig.DisableIPv6Discovery {
|
||||
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
}
|
||||
|
||||
return ice.NewAgent(agentConfig)
|
||||
}
|
||||
|
||||
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
|
||||
relatedAdd := candidate.RelatedAddress()
|
||||
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
|
||||
@@ -74,7 +74,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
|
||||
if err != nil {
|
||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
||||
w.log.Infof("do not need to reopen relay connection")
|
||||
return
|
||||
}
|
||||
w.log.Errorf("failed to open connection via Relay: %s", err)
|
||||
|
||||
@@ -10,12 +10,12 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -43,7 +43,7 @@ type clientNetwork struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
wgInterface iface.IWGIface
|
||||
routes map[route.ID]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
peerStateUpdate chan struct{}
|
||||
@@ -53,7 +53,7 @@ type clientNetwork struct {
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
client := &clientNetwork{
|
||||
@@ -378,7 +378,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
|
||||
if rt.IsDynamic() {
|
||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
||||
|
||||
@@ -13,10 +13,10 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -48,7 +48,7 @@ type Route struct {
|
||||
currentPeerKey string
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
wgInterface iface.IWGIface
|
||||
resolverAddr string
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func NewRoute(
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface *iface.WGIface,
|
||||
wgInterface iface.IWGIface,
|
||||
resolverAddr string,
|
||||
) *Route {
|
||||
return &Route{
|
||||
@@ -303,7 +303,7 @@ func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]neti
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
||||
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
@@ -23,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -52,7 +51,7 @@ type DefaultManager struct {
|
||||
sysOps *systemops.SysOps
|
||||
statusRecorder *peer.Status
|
||||
relayMgr *relayClient.Manager
|
||||
wgInterface *iface.WGIface
|
||||
wgInterface iface.IWGIface
|
||||
pubKey string
|
||||
notifier *notifier.Notifier
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
@@ -64,7 +63,7 @@ func NewManager(
|
||||
ctx context.Context,
|
||||
pubKey string,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface *iface.WGIface,
|
||||
wgInterface iface.IWGIface,
|
||||
statusRecorder *peer.Status,
|
||||
relayMgr *relayClient.Manager,
|
||||
initialRoutes []*route.Route,
|
||||
@@ -88,10 +87,10 @@ func NewManager(
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||
return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
||||
func(prefix netip.Prefix, _ any) (any, error) {
|
||||
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
func(prefix netip.Prefix, _ struct{}) error {
|
||||
func(prefix netip.Prefix, _ any) error {
|
||||
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
)
|
||||
@@ -103,7 +102,7 @@ func NewManager(
|
||||
},
|
||||
func(prefix netip.Prefix, peerKey string) error {
|
||||
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
|
||||
if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) {
|
||||
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
|
||||
return err
|
||||
}
|
||||
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"fmt"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -3,8 +3,7 @@ package refcounter
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -13,153 +12,118 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
const logLevel = log.TraceLevel
|
||||
|
||||
// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key.
|
||||
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
|
||||
var ErrIgnore = errors.New("ignore")
|
||||
|
||||
// Ref holds the reference count and associated data for a key.
|
||||
type Ref[O any] struct {
|
||||
Count int
|
||||
Out O
|
||||
}
|
||||
|
||||
// AddFunc is the function type for adding a new key.
|
||||
// Key is the type of the key (e.g., netip.Prefix).
|
||||
type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error)
|
||||
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
|
||||
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
|
||||
|
||||
// RemoveFunc is the function type for removing a key.
|
||||
type RemoveFunc[Key, O any] func(key Key, out O) error
|
||||
|
||||
// Counter is a generic reference counter for managing keys and their associated data.
|
||||
// Key: The type of the key (e.g., netip.Prefix, string).
|
||||
//
|
||||
// I: The input type for the AddFunc. It is the input type for additional data needed
|
||||
// when adding a key, it is passed as the second argument to AddFunc.
|
||||
//
|
||||
// O: The output type for the AddFunc and RemoveFunc. This is the output returned by AddFunc.
|
||||
// It is stored and passed to RemoveFunc when the reference count reaches 0.
|
||||
//
|
||||
// The types can be aliased to a specific type using the following syntax:
|
||||
//
|
||||
// type RouteRefCounter = Counter[netip.Prefix, any, any]
|
||||
type Counter[Key comparable, I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for keys
|
||||
refCountMap map[Key]Ref[O]
|
||||
type Counter[I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for prefixes
|
||||
refCountMap map[netip.Prefix]Ref[O]
|
||||
refCountMu sync.Mutex
|
||||
// idMap keeps track of the keys associated with an ID for removal
|
||||
idMap map[string][]Key
|
||||
// idMap keeps track of the prefixes associated with an ID for removal
|
||||
idMap map[string][]netip.Prefix
|
||||
idMu sync.Mutex
|
||||
add AddFunc[Key, I, O]
|
||||
remove RemoveFunc[Key, O]
|
||||
add AddFunc[I, O]
|
||||
remove RemoveFunc[I, O]
|
||||
}
|
||||
|
||||
// New creates a new Counter instance.
|
||||
// Usage example:
|
||||
//
|
||||
// counter := New[netip.Prefix, string, string](
|
||||
// func(key netip.Prefix, in string) (out string, err error) { ... },
|
||||
// func(key netip.Prefix, out string) error { ... },`
|
||||
// )
|
||||
func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) *Counter[Key, I, O] {
|
||||
return &Counter[Key, I, O]{
|
||||
refCountMap: map[Key]Ref[O]{},
|
||||
idMap: map[string][]Key{},
|
||||
// New creates a new Counter instance
|
||||
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
|
||||
return &Counter[I, O]{
|
||||
refCountMap: map[netip.Prefix]Ref[O]{},
|
||||
idMap: map[string][]netip.Prefix{},
|
||||
add: add,
|
||||
remove: remove,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves the current reference count and associated data for a key.
|
||||
// If the key doesn't exist, it returns a zero value Ref and false.
|
||||
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
// Increment increments the reference count for the given prefix.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[key]
|
||||
return ref, ok
|
||||
}
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
|
||||
// Increment increments the reference count for the given key.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref := rm.refCountMap[key]
|
||||
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
|
||||
|
||||
// Call AddFunc only if it's a new key
|
||||
// Call AddFunc only if it's a new prefix
|
||||
if ref.Count == 0 {
|
||||
logCallerF("Calling add for key %v", key)
|
||||
out, err := rm.add(key, in)
|
||||
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
|
||||
out, err := rm.add(prefix, in)
|
||||
|
||||
if errors.Is(err, ErrIgnore) {
|
||||
return ref, nil
|
||||
}
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("failed to add for key %v: %w", key, err)
|
||||
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
|
||||
}
|
||||
ref.Out = out
|
||||
}
|
||||
|
||||
ref.Count++
|
||||
rm.refCountMap[key] = ref
|
||||
rm.refCountMap[prefix] = ref
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
|
||||
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
ref, err := rm.Increment(key, in)
|
||||
ref, err := rm.Increment(prefix, in)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("with ID: %w", err)
|
||||
}
|
||||
rm.idMap[id] = append(rm.idMap[id], key)
|
||||
rm.idMap[id] = append(rm.idMap[id], prefix)
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// Decrement decrements the reference count for the given key.
|
||||
// Decrement decrements the reference count for the given prefix.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[key]
|
||||
ref, ok := rm.refCountMap[prefix]
|
||||
if !ok {
|
||||
logCallerF("No reference found for key %v", key)
|
||||
log.Tracef("No reference found for prefix %s", prefix)
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
logCallerF("Decreasing ref count [%d -> %d] for key %v with Out [%v]", ref.Count, ref.Count-1, key, ref.Out)
|
||||
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
if ref.Count == 1 {
|
||||
logCallerF("Calling remove for key %v", key)
|
||||
if err := rm.remove(key, ref.Out); err != nil {
|
||||
return ref, fmt.Errorf("remove for key %v: %w", key, err)
|
||||
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
|
||||
}
|
||||
delete(rm.refCountMap, key)
|
||||
delete(rm.refCountMap, prefix)
|
||||
} else {
|
||||
ref.Count--
|
||||
rm.refCountMap[key] = ref
|
||||
rm.refCountMap[prefix] = ref
|
||||
}
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// DecrementWithID decrements the reference count for all keys associated with the given ID.
|
||||
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
func (rm *Counter[I, O]) DecrementWithID(id string) error {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, key := range rm.idMap[id] {
|
||||
if _, err := rm.Decrement(key); err != nil {
|
||||
for _, prefix := range rm.idMap[id] {
|
||||
if _, err := rm.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
@@ -168,77 +132,24 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// Flush removes all references and calls RemoveFunc for each key.
|
||||
func (rm *Counter[Key, I, O]) Flush() error {
|
||||
// Flush removes all references and calls RemoveFunc for each prefix.
|
||||
func (rm *Counter[I, O]) Flush() error {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for key := range rm.refCountMap {
|
||||
logCallerF("Calling remove for key %v", key)
|
||||
ref := rm.refCountMap[key]
|
||||
if err := rm.remove(key, ref.Out); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove for key %v: %w", key, err))
|
||||
for prefix := range rm.refCountMap {
|
||||
log.Tracef("Removing for prefix %s", prefix)
|
||||
ref := rm.refCountMap[prefix]
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
rm.refCountMap = map[netip.Prefix]Ref[O]{}
|
||||
|
||||
clear(rm.refCountMap)
|
||||
clear(rm.idMap)
|
||||
rm.idMap = map[string][]netip.Prefix{}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// Clear removes all references without calling RemoveFunc.
|
||||
func (rm *Counter[Key, I, O]) Clear() {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
clear(rm.refCountMap)
|
||||
clear(rm.idMap)
|
||||
}
|
||||
|
||||
func getCallerInfo(depth int, maxDepth int) (string, bool) {
|
||||
if depth >= maxDepth {
|
||||
return "", false
|
||||
}
|
||||
|
||||
pc, _, _, ok := runtime.Caller(depth)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if details := runtime.FuncForPC(pc); details != nil {
|
||||
name := details.Name()
|
||||
|
||||
lastDotIndex := strings.LastIndex(name, "/")
|
||||
if lastDotIndex != -1 {
|
||||
name = name[lastDotIndex+1:]
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, "refcounter.") {
|
||||
// +2 to account for recursion
|
||||
return getCallerInfo(depth+2, maxDepth)
|
||||
}
|
||||
|
||||
return name, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// logCaller logs a message with the package name and method of the function that called the current function.
|
||||
func logCallerF(format string, args ...interface{}) {
|
||||
if log.GetLevel() < logLevel {
|
||||
return
|
||||
}
|
||||
|
||||
if callerName, ok := getCallerInfo(3, 18); ok {
|
||||
format = fmt.Sprintf("[%s] %s", callerName, format)
|
||||
}
|
||||
|
||||
log.StandardLogger().Logf(logLevel, format, args...)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package refcounter
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
|
||||
type RouteRefCounter = Counter[netip.Prefix, struct{}, struct{}]
|
||||
type RouteRefCounter = Counter[any, any]
|
||||
|
||||
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
|
||||
type AllowedIPsRefCounter = Counter[netip.Prefix, string, string]
|
||||
type AllowedIPsRefCounter = Counter[string, string]
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"fmt"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
|
||||
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
|
||||
return nil, fmt.Errorf("server route not supported on this os")
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -22,11 +22,11 @@ type defaultServerRouter struct {
|
||||
ctx context.Context
|
||||
routes map[route.ID]*route.Route
|
||||
firewall firewall.Manager
|
||||
wgInterface *iface.WGIface
|
||||
wgInterface iface.IWGIface
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
||||
return &defaultServerRouter{
|
||||
ctx: ctx,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
@@ -94,7 +94,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveNatRule(routerPair)
|
||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove routing rules: %w", err)
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.AddNatRule(routerPair)
|
||||
err = m.firewall.InsertRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert routing rules: %w", err)
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
continue
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveNatRule(routerPair)
|
||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to remove cleanup route: %v", err)
|
||||
}
|
||||
@@ -173,15 +173,15 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
||||
// TODO: add ipv6
|
||||
source := getDefaultPrefix(route.Network)
|
||||
|
||||
destination := route.Network.Masked()
|
||||
destination := route.Network.Masked().String()
|
||||
if route.IsDynamic() {
|
||||
// TODO: add ipv6 additionally
|
||||
destination = getDefaultPrefix(destination)
|
||||
// TODO: add ipv6
|
||||
destination = "0.0.0.0/0"
|
||||
}
|
||||
|
||||
return firewall.RouterPair{
|
||||
ID: route.ID,
|
||||
Source: source,
|
||||
ID: string(route.ID),
|
||||
Source: source.String(),
|
||||
Destination: destination,
|
||||
Masquerade: route.Masquerade,
|
||||
}, nil
|
||||
|
||||
@@ -30,7 +30,7 @@ func (r *Route) String() string {
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(context.Context) error {
|
||||
_, err := r.routeRefCounter.Increment(r.route.Network, struct{}{})
|
||||
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,7 +23,7 @@ const (
|
||||
)
|
||||
|
||||
// Setup configures sysctl settings for RP filtering and source validation.
|
||||
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
func Setup(wgIface iface.IWGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type Nexthop struct {
|
||||
@@ -15,11 +15,11 @@ type Nexthop struct {
|
||||
Intf *net.Interface
|
||||
}
|
||||
|
||||
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
||||
type ExclusionCounter = refcounter.Counter[any, Nexthop]
|
||||
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface *iface.WGIface
|
||||
wgInterface iface.IWGIface
|
||||
// prefixes is tracking all the current added prefixes im memory
|
||||
// (this is used in iOS as all route updates require a full table update)
|
||||
//nolint
|
||||
@@ -30,7 +30,7 @@ type SysOps struct {
|
||||
notifier *notifier.Notifier
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
|
||||
func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
notifier: notifier,
|
||||
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
|
||||
}
|
||||
|
||||
refCounter := refcounter.New(
|
||||
func(prefix netip.Prefix, _ struct{}) (Nexthop, error) {
|
||||
func(prefix netip.Prefix, _ any) (Nexthop, error) {
|
||||
initialNexthop := initialNextHopV4
|
||||
if prefix.Addr().Is6() {
|
||||
initialNexthop = initialNextHopV6
|
||||
@@ -122,7 +122,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
@@ -317,7 +317,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
WgeBPFProxy *WGEBPFProxy
|
||||
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgEndpointAddr *net.UDPAddr
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgEndpointAddr = addr
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgEndpointAddr
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
func (e *ProxyWrapper) CloseConn() error {
|
||||
if e.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
|
||||
e.cancel()
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.readFromRemote(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
22
client/internal/wgproxy/factory.go
Normal file
22
client/internal/wgproxy/factory.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package wgproxy
|
||||
|
||||
import "context"
|
||||
|
||||
type Factory struct {
|
||||
wgPort int
|
||||
ebpfProxy Proxy
|
||||
}
|
||||
|
||||
func (w *Factory) GetProxy(ctx context.Context) Proxy {
|
||||
if w.ebpfProxy != nil {
|
||||
return w.ebpfProxy
|
||||
}
|
||||
return NewWGUserSpaceProxy(ctx, w.wgPort)
|
||||
}
|
||||
|
||||
func (w *Factory) Free() error {
|
||||
if w.ebpfProxy != nil {
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -3,26 +3,20 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Factory struct {
|
||||
wgPort int
|
||||
ebpfProxy *ebpf.WGEBPFProxy
|
||||
}
|
||||
|
||||
func NewFactory(userspace bool, wgPort int) *Factory {
|
||||
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
|
||||
f := &Factory{wgPort: wgPort}
|
||||
|
||||
if userspace {
|
||||
return f
|
||||
}
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
|
||||
err := ebpfProxy.Listen()
|
||||
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
|
||||
err := ebpfProxy.listen()
|
||||
if err != nil {
|
||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||
return f
|
||||
@@ -31,20 +25,3 @@ func NewFactory(userspace bool, wgPort int) *Factory {
|
||||
f.ebpfProxy = ebpfProxy
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *Factory) GetProxy() Proxy {
|
||||
if w.ebpfProxy != nil {
|
||||
p := &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: w.ebpfProxy,
|
||||
}
|
||||
return p
|
||||
}
|
||||
return usp.NewWGUserSpaceProxy(w.wgPort)
|
||||
}
|
||||
|
||||
func (w *Factory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
}
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
||||
|
||||
@@ -2,20 +2,8 @@
|
||||
|
||||
package wgproxy
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
import "context"
|
||||
|
||||
type Factory struct {
|
||||
wgPort int
|
||||
}
|
||||
|
||||
func NewFactory(_ bool, wgPort int) *Factory {
|
||||
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
|
||||
return &Factory{wgPort: wgPort}
|
||||
}
|
||||
|
||||
func (w *Factory) GetProxy() Proxy {
|
||||
return usp.NewWGUserSpaceProxy(w.wgPort)
|
||||
}
|
||||
|
||||
func (w *Factory) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package ebpf
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,4 +1,4 @@
|
||||
package ebpf
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,15 +1,12 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||
// Proxy is a transfer layer between the Turn connection and the WireGuard
|
||||
type Proxy interface {
|
||||
AddTurnConn(ctx context.Context, turnConn net.Conn) error
|
||||
EndpointAddr() *net.UDPAddr
|
||||
Work()
|
||||
Pause()
|
||||
AddTurnConn(turnConn net.Conn) (net.Addr, error)
|
||||
CloseConn() error
|
||||
Free() error
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -12,49 +13,47 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
ebpfManager ebpfMgr.Manager
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
lastUsedPort uint16
|
||||
localWGListenPort int
|
||||
|
||||
ebpfManager ebpfMgr.Manager
|
||||
turnConnStore map[uint16]net.Conn
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
lastUsedPort uint16
|
||||
rawConn net.PacketConn
|
||||
conn transport.UDPConn
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
rawConn net.PacketConn
|
||||
conn transport.UDPConn
|
||||
}
|
||||
|
||||
// NewWGEBPFProxy create new WGEBPFProxy instance
|
||||
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
||||
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
|
||||
log.Debugf("instantiate ebpf proxy")
|
||||
wgProxy := &WGEBPFProxy{
|
||||
localWGListenPort: wgPort,
|
||||
ebpfManager: ebpf.GetEbpfManagerInstance(),
|
||||
lastUsedPort: 0,
|
||||
turnConnStore: make(map[uint16]net.Conn),
|
||||
}
|
||||
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
|
||||
|
||||
return wgProxy
|
||||
}
|
||||
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
// listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) listen() error {
|
||||
pl := portLookup{}
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
@@ -73,14 +72,13 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: wgPorxyPort,
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
|
||||
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
|
||||
|
||||
conn, err := nbnet.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
if cErr := p.Free(); cErr != nil {
|
||||
cErr := p.Free()
|
||||
if cErr != nil {
|
||||
log.Errorf("Failed to close the wgproxy: %s", cErr)
|
||||
}
|
||||
return err
|
||||
@@ -93,84 +91,108 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
}
|
||||
|
||||
// AddTurnConn add new turn connection for the proxy
|
||||
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
|
||||
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go p.proxyToLocal(wgEndpointPort, turnConn)
|
||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||
|
||||
wgEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: int(wgEndpointPort),
|
||||
}
|
||||
return wgEndpoint, nil
|
||||
}
|
||||
|
||||
// Free resources except the remoteConns will be keep open.
|
||||
// CloseConn doing nothing because this type of proxy implementation does not store the connection
|
||||
func (p *WGEBPFProxy) CloseConn() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Free resources
|
||||
func (p *WGEBPFProxy) Free() error {
|
||||
log.Debugf("free up ebpf wg proxy")
|
||||
if p.ctx != nil && p.ctx.Err() != nil {
|
||||
//nolint
|
||||
return nil
|
||||
var err1, err2, err3 error
|
||||
if p.conn != nil {
|
||||
err1 = p.conn.Close()
|
||||
}
|
||||
|
||||
p.ctxCancel()
|
||||
err2 = p.ebpfManager.FreeWGProxy()
|
||||
if p.rawConn != nil {
|
||||
err3 = p.rawConn.Close()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
if p.conn != nil { // p.conn will be nil if we have failed to listen
|
||||
if err := p.conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
|
||||
return err3
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
||||
buf := make([]byte, 1500)
|
||||
var err error
|
||||
defer func() {
|
||||
p.removeTurnConn(endpointPort)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
var n int
|
||||
n, err = remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
err = p.sendPkg(buf[:n], endpointPort)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.ebpfManager.FreeWGProxy(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if err := p.rawConn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
buf := make([]byte, 1500)
|
||||
for p.ctx.Err() == nil {
|
||||
if err := p.readAndForwardPacket(buf); err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, addr, err := p.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read UDP pkg from WG: %s", err)
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to proxy packet to remote conn: %s", err)
|
||||
|
||||
p.turnConnMutex.Lock()
|
||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||
p.turnConnMutex.Unlock()
|
||||
if !ok {
|
||||
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = conn.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
|
||||
n, addr, err := p.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read UDP packet from WG: %w", err)
|
||||
}
|
||||
|
||||
p.turnConnMutex.Lock()
|
||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||
p.turnConnMutex.Unlock()
|
||||
if !ok {
|
||||
if p.ctx.Err() == nil {
|
||||
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := conn.Write(buf[:n]); err != nil {
|
||||
return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
@@ -184,14 +206,11 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
|
||||
log.Debugf("remove turn conn from store by port: %d", turnConnID)
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
|
||||
_, ok := p.turnConnStore[turnConnID]
|
||||
if ok {
|
||||
log.Debugf("remove turn conn from store by port: %d", turnConnID)
|
||||
}
|
||||
delete(p.turnConnStore, turnConnID)
|
||||
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
|
||||
@@ -249,7 +268,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
||||
localhost := net.ParseIP("127.0.0.1")
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
@@ -1,13 +1,14 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWGEBPFProxy_connStore(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
||||
|
||||
p, _ := wgProxy.storeTurnConn(nil)
|
||||
if p != 1 {
|
||||
@@ -27,7 +28,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
||||
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
wgProxy.lastUsedPort = 65535
|
||||
@@ -43,7 +44,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
||||
|
||||
for i := 0; i < 65535; i++ {
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
@@ -1,128 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("trace", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
type mocConn struct {
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newMockConn() *mocConn {
|
||||
return &mocConn{
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mocConn) Read(b []byte) (n int, err error) {
|
||||
<-m.closeChan
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (m *mocConn) Write(b []byte) (n int, err error) {
|
||||
<-m.closeChan
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (m *mocConn) Close() error {
|
||||
if m.closed == true {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.closed = true
|
||||
close(m.closeChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mocConn) LocalAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{
|
||||
IP: net.ParseIP("172.16.254.1"),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mocConn) SetDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) SetReadDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "userspace proxy",
|
||||
proxy: usp.NewWGUserSpaceProxy(51830),
|
||||
},
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
proxyWrapper := &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
}
|
||||
|
||||
tests = append(tests, struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
name: "ebpf proxy",
|
||||
proxy: proxyWrapper,
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
_ = relayedConn.Close()
|
||||
if err := tt.proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
130
client/internal/wgproxy/proxy_userspace.go
Normal file
130
client/internal/wgproxy/proxy_userspace.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// WGUserSpaceProxy proxies
|
||||
type WGUserSpaceProxy struct {
|
||||
localWGListenPort int
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
}
|
||||
|
||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
|
||||
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
|
||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||
p := &WGUserSpaceProxy{
|
||||
localWGListenPort: wgPort,
|
||||
}
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn start the proxy with the given remote conn
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go p.proxyToRemote()
|
||||
go p.proxyToLocal()
|
||||
|
||||
return p.localConn.LocalAddr(), err
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
||||
p.cancel()
|
||||
if p.localConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.remoteConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.remoteConn.Close(); err != nil {
|
||||
log.Warnf("failed to close remote conn: %s", err)
|
||||
}
|
||||
return p.localConn.Close()
|
||||
}
|
||||
|
||||
// Free doing nothing because this implementation of proxy does not have global state
|
||||
func (p *WGUserSpaceProxy) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||
// blocks
|
||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
p.cancel()
|
||||
} else {
|
||||
log.Debugf("failed to write to remote conn: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||
// blocks
|
||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||
defer p.cancel()
|
||||
defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
package usp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// WGUserSpaceProxy proxies
|
||||
type WGUserSpaceProxy struct {
|
||||
localWGListenPort int
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||
p := &WGUserSpaceProxy{
|
||||
localWGListenPort: wgPort,
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.localConn = localConn
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
||||
if p.localConn == nil {
|
||||
return nil
|
||||
}
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||
return endpointUdpAddr
|
||||
}
|
||||
|
||||
// Work starts the proxy or resumes it if it was paused
|
||||
func (p *WGUserSpaceProxy) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToRemote(p.ctx)
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
func (p *WGUserSpaceProxy) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
return p.close()
|
||||
}
|
||||
|
||||
func (p *WGUserSpaceProxy) close() error {
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
// prevent double close
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.remoteConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
|
||||
if err := p.localConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||
}
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to remote loop: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for ctx.Err() == nil {
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("failed to write to remote conn: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to local loop: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,41 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// only if env is not set
|
||||
if os.Getenv("NB_LOG_LEVEL") == "" {
|
||||
if err := os.Setenv("NB_LOG_LEVEL", "debug"); err != nil {
|
||||
log.Errorf("Failed setting log-level: %v", err)
|
||||
}
|
||||
}
|
||||
if err := os.Setenv("NB_LOG_MAX_SIZE_MB", "100"); err != nil {
|
||||
log.Errorf("Failed setting log-size: %v", err)
|
||||
}
|
||||
if err := os.Setenv("NB_WINDOWS_PANIC_LOG", filepath.Join(os.Getenv("ProgramData"), "netbird", "netbird.err")); err != nil {
|
||||
log.Errorf("Failed setting panic log path: %v", err)
|
||||
}
|
||||
|
||||
go startPprofServer()
|
||||
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func startPprofServer() {
|
||||
pprofAddr := "localhost:6969"
|
||||
log.Infof("Starting pprof debugging server on %s", pprofAddr)
|
||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||
log.Infof("pprof server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,6 +110,35 @@ func (s *Server) Start() error {
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if statusResp, err := s.Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true}); err != nil {
|
||||
log.Infof("Error getting status: %v", err)
|
||||
} else if statusResp.FullStatus != nil {
|
||||
log.Infof("Status --------")
|
||||
for _, peer := range statusResp.FullStatus.Peers {
|
||||
log.Infof("[Peer Connection] Name: %s, IP: %s, Key: %s, Connection Status: %s, Relayed: %v, RelayedAddress: %v, Last WireGuard Handshake: %v",
|
||||
peer.Fqdn,
|
||||
peer.IP,
|
||||
peer.PubKey,
|
||||
peer.ConnStatus,
|
||||
peer.Relayed,
|
||||
peer.RelayAddress,
|
||||
peer.LastWireguardHandshake.AsTime().Format("15:04:05"),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
|
||||
// on failure we return error to retry
|
||||
config, err := internal.UpdateConfig(s.latestConfigInput)
|
||||
|
||||
@@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -160,7 +160,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
srv, err := signalServer.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
|
||||
38
client/testdata/store.json
vendored
Normal file
38
client/testdata/store.json
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"Accounts": {
|
||||
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
|
||||
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
"SetupKeys": {
|
||||
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
|
||||
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||
"Name": "Default key",
|
||||
"Type": "reusable",
|
||||
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
|
||||
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
|
||||
"Revoked": false,
|
||||
"UsedTimes": 0
|
||||
|
||||
}
|
||||
},
|
||||
"Network": {
|
||||
"Id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
|
||||
"Net": {
|
||||
"IP": "100.64.0.0",
|
||||
"Mask": "//8AAA=="
|
||||
},
|
||||
"Dns": null
|
||||
},
|
||||
"Peers": {},
|
||||
"Users": {
|
||||
"edafee4e-63fb-11ec-90d6-0242ac120003": {
|
||||
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
"Role": "admin"
|
||||
},
|
||||
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
|
||||
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
|
||||
"Role": "user"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
BIN
client/testdata/store.sqlite
vendored
BIN
client/testdata/store.sqlite
vendored
Binary file not shown.
4
go.mod
4
go.mod
@@ -59,8 +59,8 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
|
||||
8
go.sum
8
go.sum
@@ -521,12 +521,12 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package device
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// WGAddress WireGuard parsed address
|
||||
// WGAddress Wireguard parsed address
|
||||
type WGAddress struct {
|
||||
IP net.IP
|
||||
Network *net.IPNet
|
||||
}
|
||||
|
||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
func ParseWGAddress(address string) (WGAddress, error) {
|
||||
// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
func parseWGAddress(address string) (WGAddress, error) {
|
||||
ip, network, err := net.ParseCIDR(address)
|
||||
if err != nil {
|
||||
return WGAddress{}, err
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user