Compare commits

..

8 Commits

Author SHA1 Message Date
Zoltan Papp
d68a4a7d21 Replace the grpc header key to NetBird specific 2023-06-28 12:58:05 +02:00
Zoltan Papp
ca1722ed10 Handle the stream sending in thread safe way 2023-06-28 02:08:09 +02:00
Zoltan Papp
649dbf2bed Change log line 2023-06-26 12:47:06 +02:00
Zoltan Papp
70076b98d2 Fix metadata preparation in signal 2023-06-23 13:07:14 +02:00
Zoltan Papp
551455f314 Handle keep alive in signal server 2023-06-23 13:07:12 +02:00
pzoli
a6431e053b Update protobuf 2023-06-23 13:03:34 +02:00
Zoltan Papp
520c7b5d37 Move keepalive out of mgm pkg 2023-06-23 13:02:50 +02:00
Zoltan Papp
e376541745 Add grpc keep alive for management service 2023-06-23 13:01:31 +02:00
372 changed files with 9263 additions and 22159 deletions

View File

@@ -1,15 +0,0 @@
FROM golang:1.20-bullseye
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\
gettext-base=0.21-4 \
iptables=1.8.7-1 \
libgl1-mesa-dev=20.3.5-1 \
xorg-dev=1:7.7+22 \
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& go install -v golang.org/x/tools/gopls@latest
WORKDIR /app

View File

@@ -1,20 +0,0 @@
{
"name": "NetBird",
"build": {
"context": "..",
"dockerfile": "Dockerfile"
},
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
"ghcr.io/devcontainers/features/go:1": {
"version": "1.20"
}
},
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
"capAdd": [
"NET_ADMIN",
"SYS_ADMIN",
"SYS_RESOURCE"
],
"privileged": true
}

1
.gitattributes vendored
View File

@@ -1 +0,0 @@
*.go text eol=lf

View File

@@ -1,41 +0,0 @@
name: Android build validation
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v2
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v3
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
- name: Setup NDK
run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build android nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -12,20 +12,17 @@ concurrency:
jobs:
test:
strategy:
matrix:
store: ['jsonfile', 'sqlite']
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.20.x"
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v2
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }}
@@ -36,4 +33,4 @@ jobs:
run: go mod tidy
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
run: go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...

View File

@@ -15,17 +15,16 @@ jobs:
strategy:
matrix:
arch: ['386','amd64']
store: ['jsonfile', 'sqlite']
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.20.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +32,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v2
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
@@ -42,18 +41,19 @@ jobs:
run: go mod tidy
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.20.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -61,10 +61,10 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v2
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: Install modules
run: go mod tidy
@@ -82,7 +82,7 @@ jobs:
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
- name: Generate Engine Test bin
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
run: CGO_ENABLED=0 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
@@ -95,17 +95,15 @@ jobs:
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with file store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with sqlite store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -39,9 +39,7 @@ jobs:
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
- run: choco install -y sysinternals --ignore-checksums
- run: choco install -y mingw
- run: choco install -y sysinternals
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build

View File

@@ -1,48 +1,21 @@
name: golangci-lint
on: [pull_request]
permissions:
contents: read
pull-requests: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
codespell:
name: codespell
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta
skip: go.mod,go.sum
only_warn: 1
golangci:
strategy:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
name: lint
runs-on: ${{ matrix.os }}
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v3
- uses: actions/checkout@v2
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.20.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
uses: golangci/golangci-lint-action@v2
with:
version: latest
args: --timeout=12m
args: --timeout=6m

View File

@@ -1,36 +0,0 @@
name: Test installation
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-install-script:
strategy:
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
skip_ui_mode: [true, false]
install_binary: [true, false]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: run install script
env:
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
USE_BIN_INSTALL: ${{ matrix.install_binary }}
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
run: |
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
cat release_files/install.sh | sh -x
- name: check cli binary
run: command -v netbird

View File

@@ -0,0 +1,60 @@
name: Test installation Darwin
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
env:
SKIP_UI_APP: true
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
install-all:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
echo "Error: NetBird UI is not installed"
exit 1
fi

View File

@@ -0,0 +1,38 @@
name: Test installation Linux
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: ubuntu-latest
strategy:
matrix:
check_bin_install: [true, false]
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename apt package
if: ${{ matrix.check_bin_install }}
run: |
sudo mv /usr/bin/apt /usr/bin/apt.bak
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi

View File

@@ -7,20 +7,9 @@ on:
branches:
- main
pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.goreleaser.yml'
- '.goreleaser_ui.yaml'
- '.goreleaser_ui_darwin.yaml'
- '.github/workflows/release.yml'
- 'release_files/**'
- '**/Dockerfile'
- '**/Dockerfile.*'
- 'client/ui/**'
env:
SIGN_PIPE_VER: "v0.0.10"
SIGN_PIPE_VER: "v0.0.8"
GORELEASER_VER: "v1.14.1"
concurrency:
@@ -30,24 +19,20 @@ concurrency:
jobs:
release:
runs-on: ubuntu-latest
env:
flags: ""
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.20"
-
name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -61,10 +46,10 @@ jobs:
run: git --no-pager diff --exit-code
-
name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@v1
-
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v1
-
name: Login to Docker hub
if: github.event_name != 'pull_request'
@@ -87,10 +72,10 @@ jobs:
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@v2
with:
version: ${{ env.GORELEASER_VER }}
args: release --rm-dist ${{ env.flags }}
args: release --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
@@ -98,7 +83,7 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v2
with:
name: release
path: dist/
@@ -107,19 +92,17 @@ jobs:
release_ui:
runs-on: ubuntu-latest
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.20"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
@@ -133,23 +116,23 @@ jobs:
run: git --no-pager diff --exit-code
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@v2
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
args: release --config .goreleaser_ui.yaml --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v2
with:
name: release-ui
path: dist/
@@ -158,21 +141,19 @@ jobs:
release_ui_darwin:
runs-on: macos-11
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.20"
-
name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
@@ -184,15 +165,15 @@ jobs:
-
name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@v2
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v2
with:
name: release-ui-darwin
path: dist/

View File

@@ -1,22 +1,18 @@
name: Test Infrastructure files
name: Test Docker Compose Linux
on:
push:
branches:
- main
pull_request:
paths:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
- 'management/cmd/**'
- 'signal/cmd/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-docker-compose:
test:
runs-on: ubuntu-latest
steps:
- name: Install jq
@@ -26,12 +22,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.20.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -39,7 +35,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v2
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -57,9 +53,6 @@ jobs:
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
working-directory: infrastructure_files
@@ -75,7 +68,6 @@ jobs:
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
CI_NETBIRD_TOKEN_SOURCE: "idToken"
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
@@ -84,9 +76,6 @@ jobs:
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
run: |
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
@@ -98,14 +87,11 @@ jobs:
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE"
grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH"
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
grep UseIDToken management.json | grep false
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
@@ -113,34 +99,6 @@ jobs:
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
- name: Install modules
run: go mod tidy
- name: Build management binary
working-directory: management
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
- name: Build management docker image
working-directory: management
run: |
docker build -t netbirdio/management:latest .
- name: Build signal binary
working-directory: signal
run: CGO_ENABLED=0 go build -o netbird-signal main.go
- name: Build signal docker image
working-directory: signal
run: |
docker build -t netbirdio/signal:latest .
- name: run docker compose up
working-directory: infrastructure_files
@@ -152,31 +110,6 @@ jobs:
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("infrastructure_files")) | .State' | grep -c running)
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
test $count -eq 4
working-directory: infrastructure_files
test-getting-started-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v3
- name: run script
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen
run: test -f Caddyfile
- name: test docker-compose file gen
run: test -f docker-compose.yml
- name: test management.json file gen
run: test -f management.json
- name: test turnserver.conf file gen
run: test -f turnserver.conf
- name: test zitadel.env file gen
run: test -f zitadel.env
- name: test dashboard.env file gen
run: test -f dashboard.env

2
.gitignore vendored
View File

@@ -19,5 +19,3 @@ client/.distfiles/
infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
*.db

View File

@@ -1,123 +0,0 @@
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 6m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: false
gosec:
includes:
- G101 # Look for hard coded credentials
#- G102 # Bind to all interfaces
- G103 # Audit the use of unsafe block
- G104 # Audit errors not checked
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
#- G107 # Url provided to HTTP request as taint input
- G108 # Profiling endpoint automatically exposed on /debug/pprof
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
- G110 # Potential DoS vulnerability via decompression bomb
- G111 # Potential directory traversal
#- G112 # Potential slowloris attack
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
#- G114 # Use of net/http serve function that has no support for setting timeouts
- G201 # SQL query construction using format string
- G202 # SQL query construction using string concatenation
- G203 # Use of unescaped data in HTML templates
#- G204 # Audit use of command execution
- G301 # Poor file permissions used when creating a directory
- G302 # Poor file permissions used with chmod
- G303 # Creating tempfile using a predictable path
- G304 # File path provided as taint input
- G305 # File traversal when extracting zip/tar archive
- G306 # Poor file permissions used when writing to a new file
- G307 # Poor file permissions used when creating a file with os.Create
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
#- G402 # Look for bad TLS connection settings
- G403 # Ensure minimum RSA key length of 2048 bits
#- G404 # Insecure random number source (rand)
#- G501 # Import blocklist: crypto/md5
- G502 # Import blocklist: crypto/des
- G503 # Import blocklist: crypto/rc4
- G504 # Import blocklist: net/http/cgi
#- G505 # Import blocklist: crypto/sha1
- G601 # Implicit memory aliasing of items from a range statement
- G602 # Slice access out of bounds
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
govet:
# Enable all analyzers.
# Default: false
enable-all: false
enable:
- nilness
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
# Default: false
all: true
linters:
disable-all: true
enable:
## enabled by default
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- gosimple # specializes in simplifying a code
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # detects when assignments to existing variables are not used
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- unused # checks for unused constants, variables, functions and types
## disable by default but the have interesting results so lets add them
- bodyclose # checks whether HTTP response body is closed successfully
- dupword # dupword checks for duplicate words in the source code
- durationcheck # durationcheck checks for two durations multiplied together
- forbidigo # forbidigo forbids identifiers
- gocritic # provides diagnostics that check for bugs, performance and style issues
- gosec # inspects source code for security problems
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
- misspell # misspess finds commonly misspelled English words in comments
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5
exclude-rules:
# allow fmt
- path: management/cmd/root\.go
linters: forbidigo
- path: signal/cmd/root\.go
linters: forbidigo
- path: sharedsock/filter\.go
linters:
- unused
- path: client/firewall/iptables/rule\.go
linters:
- unused
- path: test\.go
linters:
- mirror
- gosec
- path: mock\.go
linters:
- nilnil

View File

@@ -377,13 +377,3 @@ uploads:
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com
method: PUT
checksum:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh

View File

@@ -11,8 +11,6 @@ builds:
- amd64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows
@@ -54,9 +52,12 @@ nfpms:
contents:
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-default.png
- src: client/ui/disconnected.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- libayatana-appindicator3-1
- libgtk-3-dev
- libappindicator3-dev
- netbird
- maintainer: Netbird <dev@netbird.io>
@@ -71,9 +72,12 @@ nfpms:
contents:
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-default.png
- src: client/ui/disconnected.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- libayatana-appindicator3-1
- libgtk-3-dev
- libappindicator3-dev
- netbird
uploads:
@@ -91,4 +95,4 @@ uploads:
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com
method: PUT
method: PUT

View File

@@ -23,6 +23,7 @@ If you haven't already, join our slack workspace [here](https://join.slack.com/t
- [Test suite](#test-suite)
- [Checklist before submitting a PR](#checklist-before-submitting-a-pr)
- [Other project repositories](#other-project-repositories)
- [Checklist before submitting a new node](#checklist-before-submitting-a-new-node)
- [Contributor License Agreement](#contributor-license-agreement)
## Code of conduct
@@ -69,7 +70,7 @@ dependencies are installed. Here is a short guide on how that can be done.
### Requirements
#### Go 1.21
#### Go 1.19
Follow the installation guide from https://go.dev/
@@ -138,14 +139,15 @@ checked out and set up:
### Build and start
#### Client
> Windows clients have a Wireguard driver requirement. We provide a bash script that can be executed in WLS 2 with docker support [wireguard_nt.sh](/client/wireguard_nt.sh).
To start NetBird, execute:
```
cd client
CGO_ENABLED=0 go build .
# bash wireguard_nt.sh # if windows
go build .
```
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to the same path as your binary file or to `C:\Windows\System32\wintun.dll`.
To start NetBird the client in the foreground:
```
@@ -183,42 +185,6 @@ To start NetBird the management service:
./management management --log-level debug --log-file console --config ./management.json
```
#### Windows Netbird Installer
Create dist directory
```shell
mkdir -p dist/netbird_windows_amd64
```
UI client
```shell
CC=x86_64-w64-mingw32-gcc CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -o netbird-ui.exe -ldflags "-s -w -H windowsgui" ./client/ui
mv netbird-ui.exe ./dist/netbird_windows_amd64/
```
Client
```shell
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o netbird.exe ./client/
mv netbird.exe ./dist/netbird_windows_amd64/
```
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to `./dist/netbird_windows_amd64/`.
NSIS compiler
- [Windows-nsis]( https://nsis.sourceforge.io/Download)
- [MacOS-makensis](https://formulae.brew.sh/formula/makensis#default)
- [Linux-makensis](https://manpages.ubuntu.com/manpages/trusty/man1/makensis.1.html)
NSIS Plugins. Download and move them to the NSIS plugins folder.
- [EnVar](https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip)
- [ShellExecAsUser](https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z)
Windows Installer
```shell
export APPVER=0.0.0.1
makensis -V4 client/installer.nsis
```
The installer `netbird-installer.exe` will be created in root directory.
### Test suite
The tests can be started via:
@@ -249,4 +215,4 @@ NetBird project is composed of 3 main repositories:
That we do not have any potential problems later it is sadly necessary to sign a [Contributor License Agreement](CONTRIBUTOR_LICENSE_AGREEMENT.md). That can be done literally with the push of a button.
A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in.
A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in.

View File

@@ -1,6 +1,6 @@
<p align="center">
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
<strong>:hatching_chick: New Release! Peer expiration.</strong>
<a href="https://github.com/netbirdio/netbird/releases">
Learn more
</a>
</p>
@@ -24,7 +24,7 @@
<p align="center">
<strong>
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
Start using NetBird at <a href="https://app.netbird.io/">app.netbird.io</a>
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
@@ -36,62 +36,47 @@
<br>
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
**NetBird is an open-source VPN management platform built on top of WireGuard® making it easy to create secure private networks for your organization or home.**
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
It requires zero configuration effort leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactive_Connectivity_Establishment) to automatically create an overlay peer-to-peer network connecting machines regardless of location (home, office, data center, container, cloud, or edge environments), unifying virtual private network management experience.
**Key features:**
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
- \[x] Multiuser support - sharing network between multiple users.
- \[x] SSO and MFA support.
- \[x] Multicloud and hybrid-cloud support.
- \[x] Kernel WireGuard usage when possible.
- \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys.
- \[x] Network Routes.
- \[x] Private DNS.
- \[x] Network Activity Monitoring.
**Coming soon:**
- \[ ] Mobile clients.
### Secure peer-to-peer VPN with SSO and MFA in minutes
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
### Key features
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
| Connectivity | Management | Automation | Platforms |
|-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[ ] iOS </ul></li> |
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
| | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
| | <ul><li> - \[x] SSH access management </ul></li> | | |
### Start using NetBird
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart).
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting).
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms.
- Web UI [repository](https://github.com/netbirdio/dashboard).
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
- Check NetBird [admin UI](https://app.netbird.io/).
- Add more machines.
### Quickstart with self-hosted NetBird
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
**Infrastructure requirements:**
- A Linux VM with at least **1CPU** and **2GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
- **Public domain** name pointing to the VM.
**Software requirements:**
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
- [curl](https://curl.se/) installed.
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
**Steps**
- Download and run the installation script:
```bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`
### A bit on NetBird internals
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
@@ -103,18 +88,18 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details.
### Roadmap
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
### Community projects
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
- [NetBird installer script](https://github.com/physk/netbird-installer)
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
### Support acknowledgement
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
@@ -122,7 +107,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
### Legal
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -18,9 +18,10 @@ func Encode(num uint32) string {
}
var encoded strings.Builder
remainder := uint32(0)
for num > 0 {
remainder := num % base
remainder = num % base
encoded.WriteByte(alphabet[remainder])
num /= base
}

View File

@@ -1,5 +1,7 @@
FROM alpine:3
RUN apk add --no-cache ca-certificates iptables ip6tables
FROM gcr.io/distroless/base:debug
ENV NB_FOREGROUND_MODE=true
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
SHELL ["/busybox/sh","-c"]
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird

View File

@@ -7,9 +7,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
@@ -31,14 +30,9 @@ type IFaceDiscover interface {
stdnet.ExternalIFaceDiscover
}
// NetworkChangeListener export internal NetworkChangeListener for mobile
type NetworkChangeListener interface {
listener.NetworkChangeListener
}
// DnsReadyListener export internal dns ReadyListener for mobile
type DnsReadyListener interface {
dns.ReadyListener
// RouteListener export internal RouteListener for mobile
type RouteListener interface {
routemanager.RouteListener
}
func init() {
@@ -47,31 +41,31 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
tunAdapter iface.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
deviceName string
networkChangeListener listener.NetworkChangeListener
cfgFile string
tunAdapter iface.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
deviceName string
routeListener routemanager.RouteListener
}
// NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, routeListener RouteListener) *Client {
return &Client{
cfgFile: cfgFile,
deviceName: deviceName,
tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
cfgFile: cfgFile,
deviceName: deviceName,
tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
routeListener: routeListener,
}
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
func (c *Client) Run(urlOpener URLOpener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -96,31 +90,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
return err
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel()
c.ctxCancelLock.Unlock()
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener)
}
// Stop the internal client and free the resources
@@ -156,17 +126,6 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos}
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()
if err != nil {
return err
}
dnsServer.OnUpdatedHostDNSServer(list.items)
return nil
}
// SetConnectionListener set the network connection listener
func (c *Client) SetConnectionListener(listener ConnectionListener) {
c.recorder.SetConnectionListener(listener)

View File

@@ -1,26 +0,0 @@
package android
import "fmt"
// DNSList is a wrapper of []string
type DNSList struct {
items []string
}
// Add new DNS address to the collection
func (array *DNSList) Add(s string) {
array.items = append(array.items, s)
}
// Get return an element of the collection
func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range")
}
return array.items[i], nil
}
// Size return with the size of the collection
func (array *DNSList) Size() int {
return len(array.items)
}

View File

@@ -1,24 +0,0 @@
package android
import "testing"
func TestDNSList_Get(t *testing.T) {
l := DNSList{
items: make([]string, 1),
}
_, err := l.Get(0)
if err != nil {
t.Errorf("invalid error: %s", err)
}
_, err = l.Get(-1)
if err == nil {
t.Errorf("expected error but got nil")
}
_, err = l.Get(1)
if err == nil {
t.Errorf("expected error but got nil")
}
}

View File

@@ -6,14 +6,15 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/client/internal"
)
// SSOListener is async listener for mobile framework
@@ -84,21 +85,11 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
return err
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
supportsSSO = false
err = nil
}
return err
})
@@ -192,23 +183,35 @@ func (a *Auth) login(urlOpener URLOpener) error {
return nil
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) {
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if err != nil {
return nil, err
s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
} else {
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
}
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil {
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
return nil, fmt.Errorf("getting a request device code failed: %v", err)
}
go urlOpener.Open(flowInfo.VerificationURIComplete)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
waitTimeout := time.Duration(flowInfo.ExpiresIn)
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

View File

@@ -57,11 +57,11 @@ func TestPreferences_ReadUncommitedValues(t *testing.T) {
p.SetManagementURL(exampleString)
resp, err = p.GetManagementURL()
if err != nil {
t.Fatalf("failed to read management url: %s", err)
t.Fatalf("failed to read managmenet url: %s", err)
}
if resp != exampleString {
t.Errorf("unexpected management url: %s", resp)
t.Errorf("unexpected managemenet url: %s", resp)
}
p.SetPreSharedKey(exampleString)
@@ -102,11 +102,11 @@ func TestPreferences_Commit(t *testing.T) {
resp, err = p.GetManagementURL()
if err != nil {
t.Fatalf("failed to read management url: %s", err)
t.Fatalf("failed to read managmenet url: %s", err)
}
if resp != exampleURL {
t.Errorf("unexpected management url: %s", resp)
t.Errorf("unexpected managemenet url: %s", resp)
}
resp, err = p.GetPreSharedKey()

View File

@@ -3,20 +3,20 @@ package cmd
import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
)
var loginCmd = &cobra.Command{
@@ -81,11 +81,9 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn)
loginRequest := proto.LoginRequest{
SetupKey: setupKey,
PreSharedKey: preSharedKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
SetupKey: setupKey,
PreSharedKey: preSharedKey,
ManagementUrl: managementURL,
}
var loginErr error
@@ -115,7 +113,7 @@ var loginCmd = &cobra.Command{
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
@@ -165,24 +163,40 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) {
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return nil, err
s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
mgmtURL := managementURL
if mgmtURL == "" {
mgmtURL = internal.DefaultManagementURL
}
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your servver or use Setup Keys to login", mgmtURL)
} else {
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
}
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil {
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
return nil, fmt.Errorf("getting a request device code failed: %v", err)
}
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
waitTimeout := time.Duration(flowInfo.ExpiresIn)
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
@@ -192,21 +206,15 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
if !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
cmd.Println("Please do the SSO login in your browser. \n" +
err := open.Run(verificationURIComplete)
cmd.Printf("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg)
cmd.Println("")
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
" " + verificationURIComplete + " " + codeMsg + " \n\n")
if err != nil {
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
}
}
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@@ -92,7 +92,7 @@ func init() {
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")

View File

@@ -73,8 +73,7 @@ var sshCmd = &cobra.Command{
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
log.Debug(err)
os.Exit(1)
log.Print(err)
}
cancel()
}()
@@ -93,10 +92,12 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"You can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
cmd.Printf("Couldn't connect. " +
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
"Run the status command: \n\n" +
" netbird status\n\n" +
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
return nil
}
go func() {
<-ctx.Done()

View File

@@ -109,9 +109,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(context.Background())
resp, err := getStatus(ctx, cmd)
resp, _ := getStatus(ctx, cmd)
if err != nil {
return err
return nil
}
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
@@ -120,7 +120,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
" netbird up \n\n"+
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
"More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n",
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
resp.GetStatus(),
)
return nil
@@ -133,7 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
outputInformationHolder := convertToStatusOutputOverview(resp)
var statusOutputString string
statusOutputString := ""
switch {
case detailFlag:
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
@@ -234,7 +234,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
continue
}
if isPeerConnected {
peersConnected++
peersConnected = peersConnected + 1
localICE = pbPeerState.GetLocalIceCandidateType()
remoteICE = pbPeerState.GetRemoteIceCandidateType()
@@ -407,7 +407,7 @@ func parsePeers(peers peersStateOutput) string {
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
)
peersString += peerString
peersString = peersString + peerString
}
return peersString
}

View File

@@ -22,7 +22,6 @@ import (
)
func startTestingServices(t *testing.T) string {
t.Helper()
config := &mgmt.Config{}
_, err := util.ReadJson("../testdata/management.json", config)
if err != nil {
@@ -45,7 +44,6 @@ func startTestingServices(t *testing.T) string {
}
func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
@@ -62,29 +60,28 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
}
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
store, err := mgmt.NewFileStore(config.Datadir, nil)
if err != nil {
t.Fatal(err)
}
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
peersUpdateManager := mgmt.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore, false)
eventStore)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil {
t.Fatal(err)
}
@@ -101,7 +98,6 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
func startClientDaemon(
t *testing.T, ctx context.Context, managementURL, configPath string,
) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)

View File

@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
@@ -123,7 +123,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
defer func() {
err := conn.Close()
if err != nil {
log.Warnf("failed closing daemon gRPC client connection %v", err)
log.Warnf("failed closing dameon gRPC client connection %v", err)
return
}
}()
@@ -141,15 +141,13 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
}
loginRequest := proto.LoginRequest{
SetupKey: setupKey,
PreSharedKey: preSharedKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
SetupKey: setupKey,
PreSharedKey: preSharedKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
}
var loginErr error
@@ -180,7 +178,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
@@ -201,11 +199,11 @@ func validateNATExternalIPs(list []string) error {
subElements := strings.Split(element, "/")
if len(subElements) > 2 {
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"String\" or \"String/String\"", element, externalIPMapFlag)
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"String\" or \"String/String\"", element, externalIPMapFlag)
}
if len(subElements) == 1 && !isValidIP(subElements[0]) {
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
}
last := 0
@@ -260,7 +258,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
var parsed []byte
if modified {
if !isValidAddrPort(customDNSAddress) {
return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress)
return nil, fmt.Errorf("%s is invalid, it should be formated as IP:Port string or as an empty string like \"\"", customDNSAddress)
}
if customDNSAddress == "" && logFile != "console" {
parsed = []byte("empty")

View File

@@ -40,9 +40,6 @@ const (
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
@@ -54,7 +51,6 @@ type Manager interface {
dPort *Port,
direction RuleDirection,
action Action,
ipsetName string,
comment string,
) (Rule, error)
@@ -64,8 +60,5 @@ type Manager interface {
// Reset firewall to the default state
Reset() error
// Flush the changes to firewall controller
Flush() error
// TODO: migrate routemanager firewal actions to this interface
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
fw "github.com/netbirdio/netbird/client/firewall"
@@ -36,53 +35,36 @@ type Manager struct {
inputDefaultRuleSpecs []string
outputDefaultRuleSpecs []string
wgIface iFaceMapper
rulesets map[string]ruleset
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
type ruleset struct {
rule *Rule
ips map[string]string
}
// Create iptables firewall manager
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
func Create(wgIface iFaceMapper) (*Manager, error) {
m := &Manager{
wgIface: wgIface,
inputDefaultRuleSpecs: []string{
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
outputDefaultRuleSpecs: []string{
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
rulesets: make(map[string]ruleset),
}
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
// init clients for booth ipv4 and ipv6
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
}
m.ipv4Client = ipv4Client
if ipv6Supported {
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
}
}
if m.ipv4Client == nil && m.ipv6Client == nil {
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
} else {
m.ipv6Client = ipv6Client
}
if err := m.Reset(); err != nil {
@@ -93,7 +75,7 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
// AddFiltering rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
// If comment is empty rule ID is used as comment
func (m *Manager) AddFiltering(
ip net.IP,
protocol fw.Protocol,
@@ -101,7 +83,6 @@ func (m *Manager) AddFiltering(
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
m.mutex.Lock()
@@ -120,41 +101,22 @@ func (m *Manager) AddFiltering(
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
ruleID := uuid.New().String()
if ipsetName != "" {
rs, rsExists := m.rulesets[ipsetName]
if !rsExists {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
}
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
}
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
if rsExists {
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
rs.ips[ip.String()] = ruleID
return &Rule{
ruleID: ruleID,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}, nil
}
// this is new ipset so we need to create firewall rule for it
if comment == "" {
comment = ruleID
}
specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
specs := m.filterRuleSpecs(
"filter",
ip,
string(protocol),
sPortVal,
dPortVal,
direction,
action,
comment,
)
if direction == fw.RuleDirectionOUT {
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
@@ -182,24 +144,12 @@ func (m *Manager) AddFiltering(
}
}
rule := &Rule{
ruleID: ruleID,
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}
if ipsetName != "" {
// ipset name is defined and it means that this rule was created
// for it, need to associate it with ruleset
m.rulesets[ipsetName] = ruleset{
rule: rule,
ips: map[string]string{rule.ip: ruleID},
}
}
return rule, nil
return &Rule{
id: ruleID,
specs: specs,
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}, nil
}
// DeleteRule from the firewall by rule definition
@@ -220,31 +170,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
client = m.ipv6Client
}
if rs, ok := m.rulesets[r.ipsetName]; ok {
// delete IP from ruleset IPs list and ipset
if _, ok := rs.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(rs.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(rs.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
delete(m.rulesets, r.ipsetName)
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
r = rs.rule
}
if r.dst {
return client.Delete("filter", ChainOutputFilterName, r.specs...)
}
@@ -268,41 +193,6 @@ func (m *Manager) Reset() error {
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionOUT,
fw.ActionAccept,
"",
"",
)
return err
}
return nil
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// reset firewall chain, clear it and drop it
func (m *Manager) reset(client *iptables.IPTables, table string) error {
ok, err := client.ChainExists(table, ChainInputFilterName)
@@ -343,22 +233,13 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
return nil
}
for ipsetName := range m.rulesets {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
delete(m.rulesets, ipsetName)
}
return nil
}
// filterRuleSpecs returns the specs of a filtering rule
func (m *Manager) filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, ipsetName string,
table string, ip net.IP, protocol string, sPort, dPort string,
direction fw.RuleDirection, action fw.Action, comment string,
) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
@@ -368,19 +249,11 @@ func (m *Manager) filterRuleSpecs(
switch direction {
case fw.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
specs = append(specs, "-s", ip.String())
}
case fw.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
specs = append(specs, "-d", ip.String())
}
}
if protocol != "all" {
@@ -392,7 +265,8 @@ func (m *Manager) filterRuleSpecs(
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", m.actionToStr(action))
specs = append(specs, "-j", m.actionToStr(action))
return append(specs, "-m", "comment", "--comment", comment)
}
// rawClient returns corresponding iptables client for the given ip
@@ -427,7 +301,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
}
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
}
@@ -461,18 +335,3 @@ func (m *Manager) actionToStr(action fw.Action) string {
}
return "DROP"
}
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
switch {
case ipsetName == "":
return ""
case sPort != "" && dPort != "":
return ipsetName + "-sport-dport"
case sPort != "":
return ipsetName + "-sport"
case dPort != "":
return ipsetName + "-dport"
default:
return ipsetName
}
}

View File

@@ -33,8 +33,6 @@ func (i *iFaceMock) Address() iface.WGAddress {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
@@ -55,15 +53,14 @@ func TestIptablesManager(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock, true)
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second)
defer func() {
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
if err := manager.Reset(); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)
}()
@@ -71,7 +68,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
@@ -84,31 +81,33 @@ func TestIptablesManager(t *testing.T) {
Values: []int{8043: 8046},
}
rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
})
t.Run("delete first rule", func(t *testing.T) {
err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule")
if err := manager.DeleteRule(rule1); err != nil {
require.NoError(t, err, "failed to delete rule")
}
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
})
t.Run("delete second rule", func(t *testing.T) {
err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule")
if err := manager.DeleteRule(rule2); err != nil {
require.NoError(t, err, "failed to delete rule")
}
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...)
})
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset()
@@ -123,90 +122,7 @@ func TestIptablesManager(t *testing.T) {
})
}
func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface
manager, err := Create(mock, true)
require.NoError(t, err)
time.Sleep(time.Second)
defer func() {
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
}()
var rule1 fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
})
var rule2 fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{443},
}
rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
require.NoError(t, err, "failed to add rule")
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
})
t.Run("delete first rule", func(t *testing.T) {
err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
})
t.Run("delete second rule", func(t *testing.T) {
err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
err = manager.Reset()
require.NoError(t, err, "failed to reset")
})
}
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
t.Helper()
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
require.NoError(t, err, "failed to check rule")
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
@@ -232,14 +148,14 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock, true)
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second)
defer func() {
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
if err := manager.Reset(); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)
}()
@@ -251,9 +167,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
} else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")

View File

@@ -2,16 +2,13 @@ package iptables
// Rule to handle management of rules
type Rule struct {
ruleID string
ipsetName string
id string
specs []string
ip string
dst bool
v6 bool
}
// GetRuleID returns the rule id
func (r *Rule) GetRuleID() string {
return r.ruleID
return r.id
}

View File

@@ -6,14 +6,12 @@ import (
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/google/uuid"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall"
@@ -29,18 +27,13 @@ const (
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter"
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
rConn *nftables.Conn
sConn *nftables.Conn
conn *nftables.Conn
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
@@ -50,10 +43,6 @@ type Manager struct {
filterInputChainIPv6 *nftables.Chain
filterOutputChainIPv6 *nftables.Chain
rulesetManager *rulesetManager
setRemovedIPs map[string]struct{}
setRemoved map[string]*nftables.Set
wgIface iFaceMapper
}
@@ -65,23 +54,8 @@ type iFaceMapper interface {
// Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for booth type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
m := &Manager{
rConn: &nftables.Conn{},
sConn: sConn,
rulesetManager: newRuleManager(),
setRemovedIPs: map[string]struct{}{},
setRemoved: map[string]*nftables.Set{},
conn: &nftables.Conn{},
wgIface: wgIface,
}
@@ -103,7 +77,6 @@ func (m *Manager) AddFiltering(
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
m.mutex.Lock()
@@ -111,7 +84,6 @@ func (m *Manager) AddFiltering(
var (
err error
ipset *nftables.Set
table *nftables.Table
chain *nftables.Chain
)
@@ -135,46 +107,6 @@ func (m *Manager) AddFiltering(
return nil, err
}
rawIP := ip.To4()
if rawIP == nil {
rawIP = ip.To16()
}
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
if ipsetName != "" {
// if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule
// with fresh created set and set element
var isSetNew bool
ipset, err = m.rConn.GetSetByName(table, ipsetName)
if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
isSetNew = true
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
if !isSetNew {
// if we already have nftables rules with set for given direction
// just add new rule to the ruleset and return new fw.Rule object
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP)
}
// if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later
}
}
ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
@@ -214,47 +146,39 @@ func (m *Manager) AddFiltering(
})
}
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
// in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) {
// don't use IP matching if IP is ip 0.0.0.0
if s := ip.String(); s != "0.0.0.0" && s != "::" {
// source address position
addrLen := uint32(len(rawIP))
addrOffset := uint32(12)
if addrLen == 16 {
addrOffset = 8
var adrLen, adrOffset uint32
if ip.To4() == nil {
adrLen = 16
adrOffset = 8
} else {
adrLen = 4
adrOffset = 12
}
// change to destination address position if need
if direction == fw.RuleDirectionOUT {
addrOffset += addrLen
adrOffset += adrLen
}
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: addrOffset,
Len: addrLen,
Offset: adrOffset,
Len: adrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
},
)
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipsetName,
SetID: ipset.ID,
},
)
}
}
if sPort != nil && len(sPort.Values) != 0 {
@@ -295,76 +219,39 @@ func (m *Manager) AddFiltering(
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
id := uuid.New().String()
userData := []byte(strings.Join([]string{id, comment}, " "))
rule := m.rConn.InsertRule(&nftables.Rule{
_ = m.conn.InsertRule(&nftables.Rule{
Table: table,
Chain: chain,
Position: 0,
Exprs: expressions,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush insert rule: %v", err)
if err := m.conn.Flush(); err != nil {
return nil, err
}
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
return m.rulesetManager.addRule(ruleset, rawIP)
}
// getRulesetID returns ruleset ID based on given parameters
func (m *Manager) getRulesetID(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipsetName == "" {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipsetName + rulesetID
}
// createSet in given table by name
func (m *Manager) createSet(
table *nftables.Table,
rawIP []byte,
name string,
) (*nftables.Set, error) {
keyType := nftables.TypeIPAddr
if len(rawIP) == 16 {
keyType = nftables.TypeIP6Addr
}
// else we create new ipset and continue creating rule
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: keyType,
list, err := m.conn.GetRules(table, chain)
if err != nil {
return nil, err
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
return nil, fmt.Errorf("create set: %v", err)
// Add the rule to the chain
rule := &Rule{id: id}
for _, r := range list {
if bytes.Equal(r.UserData, userData) {
rule.Rule = r
break
}
}
if rule.Rule == nil {
return nil, fmt.Errorf("rule not found")
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
return rule, nil
}
// chain returns the chain for the given IP address with specific settings
@@ -381,7 +268,7 @@ func (m *Manager) chain(
if c != nil {
return c, nil
}
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
return m.createChainIfNotExists(tf, name, hook, priority, cType)
}
if ip.To4() != nil {
@@ -401,20 +288,13 @@ func (m *Manager) chain(
}
// table returns the table for the given family of the IP address
func (m *Manager) table(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
// we cache access to Netbird ACL table only
if tableName != FilterTableName {
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
}
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil {
return m.tableIPv4, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
if err != nil {
return nil, err
}
@@ -426,7 +306,7 @@ func (m *Manager) table(
return m.tableIPv6, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
if err != nil {
return nil, err
}
@@ -434,41 +314,34 @@ func (m *Manager) table(
return m.tableIPv6, nil
}
func (m *Manager) createTableIfNotExists(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.conn.ListTablesOfFamily(family)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == tableName {
if t.Name == FilterTableName {
return t, nil
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return table, nil
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
}
func (m *Manager) createChainIfNotExists(
family nftables.TableFamily,
tableName string,
name string,
hooknum nftables.ChainHook,
priority nftables.ChainPriority,
chainType nftables.ChainType,
) (*nftables.Chain, error) {
table, err := m.table(family, tableName)
table, err := m.table(family)
if err != nil {
return nil, err
}
chains, err := m.rConn.ListChainsOfTableFamily(family)
chains, err := m.conn.ListChainsOfTableFamily(family)
if err != nil {
return nil, fmt.Errorf("list of chains: %w", err)
}
@@ -489,7 +362,7 @@ func (m *Manager) createChainIfNotExists(
Policy: &polAccept,
}
chain = m.rConn.AddChain(chain)
chain = m.conn.AddChain(chain)
ifaceKey := expr.MetaKeyIIFNAME
shiftDSTAddr := 0
@@ -556,7 +429,7 @@ func (m *Manager) createChainIfNotExists(
)
}
_ = m.rConn.AddRule(&nftables.Rule{
_ = m.conn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
@@ -571,13 +444,12 @@ func (m *Manager) createChainIfNotExists(
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = m.rConn.AddRule(&nftables.Rule{
_ = m.conn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
})
if err := m.rConn.Flush(); err != nil {
if err := m.conn.Flush(); err != nil {
return nil, err
}
@@ -586,58 +458,16 @@ func (m *Manager) createChainIfNotExists(
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
nativeRule, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if nativeRule.nftRule == nil {
return nil
}
if nativeRule.nftSet != nil {
// call twice of delete set element raises error
// so we need to check if element is already removed
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
if _, ok := m.setRemovedIPs[key]; !ok {
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
if err != nil {
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
return err
}
m.setRemovedIPs[key] = struct{}{}
}
}
if m.rulesetManager.deleteRule(nativeRule) {
// deleteRule indicates that we still have IP in the ruleset
// it means we should not remove the nftables rule but need to update set
// so we prepare IP to be removed from set on the next flush call
return nil
}
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if err := m.rConn.Flush(); err != nil {
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
return err
}
nativeRule.nftRule = nil
if nativeRule.nftSet != nil {
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
}
nativeRule.nftSet = nil
}
return nil
return m.conn.Flush()
}
// Reset firewall to the default state
@@ -645,217 +475,27 @@ func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
chains, err := m.rConn.ListChains()
chains, err := m.conn.ListChains()
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.rConn.DelChain(c)
m.conn.DelChain(c)
}
}
tables, err := m.rConn.ListTables()
tables, err := m.conn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == FilterTableName {
m.rConn.DelTable(t)
m.conn.DelTable(t)
}
}
return m.rConn.Flush()
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.flushWithBackoff(); err != nil {
return err
}
// set must be removed after flush rule changes
// otherwise we will get error
for _, s := range m.setRemoved {
m.rConn.FlushSet(s)
m.rConn.DelSet(s)
}
if len(m.setRemoved) > 0 {
if err := m.flushWithBackoff(); err != nil {
return err
}
}
m.setRemovedIPs = map[string]struct{}{}
m.setRemoved = map[string]*nftables.Set{}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
tf := nftables.TableFamilyIPv4
if m.wgIface.Address().IP.To4() == nil {
tf = nftables.TableFamilyIPv6
}
chains, err := m.rConn.ListChainsOfTableFamily(tf)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
}
return nil
}
func (m *Manager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
}
return
}
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
if table == nil || chain == nil {
return nil
}
list, err := m.rConn.GetRules(table, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) != 0 {
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
log.Errorf("failed to set rule handle: %v", err)
}
}
}
return nil
}
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(AllowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
return m.conn.Flush()
}
func encodePort(port fw.Port) []byte {

View File

@@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
time.Sleep(time.Second)
defer func() {
err = manager.Reset()
@@ -75,16 +75,11 @@ func TestNftablesManager(t *testing.T) {
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules")
// test expectations:
// 1) regular rule
// 2) "accept extra routed traffic rule" for the interface
@@ -140,15 +135,12 @@ func TestNftablesManager(t *testing.T) {
err = manager.DeleteRule(rule)
require.NoError(t, err, "failed to delete rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules")
// test expectations:
// 1) "accept extra routed traffic rule" for the interface
// 2) "drop all rule" for the interface
require.Len(t, rules, 2, "expected 2 rules after deletion")
require.Len(t, rules, 2, "expected 2 rules after deleteion")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
@@ -175,7 +167,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
time.Sleep(time.Second)
defer func() {
if err := manager.Reset(); err != nil {
@@ -189,18 +181,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
} else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
err = manager.Flush()
require.NoError(t, err, "failed to flush")
}
}
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
})
}

View File

@@ -6,14 +6,11 @@ import (
// Rule to handle management of rules
type Rule struct {
nftRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip []byte
*nftables.Rule
id string
}
// GetRuleID returns the rule id
func (r *Rule) GetRuleID() string {
return r.ruleID
return r.id
}

View File

@@ -1,115 +0,0 @@
package nftables
import (
"bytes"
"fmt"
"github.com/google/nftables"
"github.com/rs/xid"
)
// nftRuleset links native firewall rule and ipset to ACL generated rules
type nftRuleset struct {
nftRule *nftables.Rule
nftSet *nftables.Set
issuedRules map[string]*Rule
rulesetID string
}
type rulesetManager struct {
rulesets map[string]*nftRuleset
nftSetName2rulesetID map[string]string
issuedRuleID2rulesetID map[string]string
}
func newRuleManager() *rulesetManager {
return &rulesetManager{
rulesets: map[string]*nftRuleset{},
nftSetName2rulesetID: map[string]string{},
issuedRuleID2rulesetID: map[string]string{},
}
}
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
ruleset, ok := r.rulesets[rulesetID]
return ruleset, ok
}
func (r *rulesetManager) createRuleset(
rulesetID string,
nftRule *nftables.Rule,
nftSet *nftables.Set,
) *nftRuleset {
ruleset := nftRuleset{
rulesetID: rulesetID,
nftRule: nftRule,
nftSet: nftSet,
issuedRules: map[string]*Rule{},
}
r.rulesets[ruleset.rulesetID] = &ruleset
if nftSet != nil {
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
}
return &ruleset
}
func (r *rulesetManager) addRule(
ruleset *nftRuleset,
ip []byte,
) (*Rule, error) {
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
return nil, fmt.Errorf("ruleset not found")
}
rule := Rule{
nftRule: ruleset.nftRule,
nftSet: ruleset.nftSet,
ruleID: xid.New().String(),
ip: ip,
}
ruleset.issuedRules[rule.ruleID] = &rule
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
return &rule, nil
}
// deleteRule from ruleset and returns true if contains other rules
func (r *rulesetManager) deleteRule(rule *Rule) bool {
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
if !ok {
return false
}
ruleset := r.rulesets[rulesetID]
if ruleset.nftRule == nil {
return false
}
delete(r.issuedRuleID2rulesetID, rule.ruleID)
delete(ruleset.issuedRules, rule.ruleID)
if len(ruleset.issuedRules) == 0 {
delete(r.rulesets, ruleset.rulesetID)
if rule.nftSet != nil {
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
}
return false
}
return true
}
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
//
// This is important to do, because after we add rule to the nftables we can't update it until
// we set correct handle value to it.
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
split := bytes.Split(nftRule.UserData, []byte(" "))
ruleset, ok := r.rulesets[string(split[0])]
if !ok {
return fmt.Errorf("ruleset not found")
}
*ruleset.nftRule = *nftRule
return nil
}

View File

@@ -1,122 +0,0 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
func TestRulesetManager_createRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{
UserData: []byte(rulesetID),
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
require.NotNil(t, ruleset, "createRuleset() failed")
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
}
func TestRulesetManager_addRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
require.NotEqual(t, rule.ruleID, "ruleID is empty")
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
ruleset2 := &nftRuleset{
rulesetID: "ruleset-2",
}
_, err = rulesetManager.addRule(ruleset2, ip)
require.Error(t, err, "addRule() should have failed")
}
func TestRulesetManager_deleteRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
ip2 := []byte("192.168.1.1")
rule2, err := rulesetManager.addRule(ruleset, ip2)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule2, "rule should not be nil")
hasNext := rulesetManager.deleteRule(rule)
require.True(t, hasNext, "deleteRule() should have returned true")
// Check that the rule is no longer in the manager.
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
hasNext = rulesetManager.deleteRule(rule2)
require.False(t, hasNext, "deleteRule() should have returned false")
}
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.0.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
nftRuleCopy := nftRule
nftRuleCopy.Handle = 2
nftRuleCopy.UserData = []byte(rulesetID)
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
require.NoError(t, err, "setNftRuleHandle() failed")
// check correct work with references
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
}
func TestRulesetManager_getRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
nftSet := nftables.Set{
ID: 2,
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
require.NotNil(t, ruleset, "createRuleset() failed")
find, ok := rulesetManager.getRuleset(rulesetID)
require.True(t, ok, "getRuleset() failed")
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
_, ok = rulesetManager.getRuleset("does-not-exist")
require.False(t, ok, "getRuleset() failed")
}

View File

@@ -1,19 +0,0 @@
//go:build !windows && !linux
package uspfilter
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}

View File

@@ -1,21 +0,0 @@
package uspfilter
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.resetHook != nil {
return m.resetHook()
}
return nil
}

View File

@@ -1,94 +0,0 @@
package uspfilter
import (
"fmt"
"os/exec"
"syscall"
log "github.com/sirupsen/logrus"
)
type action string
const (
addRule action = "add"
deleteRule action = "delete"
firewallRuleName = "Netbird"
)
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if !isWindowsFirewallReachable() {
return nil
}
if !isFirewallRuleActive(firewallRuleName) {
return nil
}
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
return fmt.Errorf("couldn't remove windows firewall: %w", err)
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if !isWindowsFirewallReachable() {
return nil
}
if isFirewallRuleActive(firewallRuleName) {
return nil
}
return manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
)
}
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
if action == addRule {
args = append(args, extraArgs...)
}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run()
}
func isWindowsFirewallReachable() bool {
args := []string{"advfirewall", "show", "allprofiles", "state"}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output()
if err != nil {
log.Infof("Windows firewall is not reachable, skipping default rule management. Using only user space rules. Error: %s", err)
return false
}
return true
}
func isFirewallRuleActive(ruleName string) bool {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output()
return err == nil
}

View File

@@ -19,20 +19,15 @@ const layerTypeAll = 0
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(iface.PacketFilter) error
Address() iface.WGAddress
}
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]Rule
// Manager userspace firewall manager
type Manager struct {
outgoingRules map[string]RuleSet
incomingRules map[string]RuleSet
outgoingRules []Rule
incomingRules []Rule
rulesIndex map[string]int
wgNetwork *net.IPNet
decoders sync.Pool
wgIface IFaceMapper
resetHook func() error
mutex sync.RWMutex
}
@@ -53,6 +48,7 @@ type decoder struct {
// Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) {
m := &Manager{
rulesIndex: make(map[string]int),
decoders: sync.Pool{
New: func() any {
d := &decoder{
@@ -66,9 +62,6 @@ func Create(iface IFaceMapper) (*Manager, error) {
return d
},
},
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
}
if err := iface.SetFilter(m); err != nil {
@@ -88,7 +81,6 @@ func (m *Manager) AddFiltering(
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
r := Rule{
@@ -132,17 +124,15 @@ func (m *Manager) AddFiltering(
}
m.mutex.Lock()
var p int
if direction == fw.RuleDirectionIN {
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
m.incomingRules = append(m.incomingRules, r)
p = len(m.incomingRules) - 1
} else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
m.outgoingRules = append(m.outgoingRules, r)
p = len(m.outgoingRules) - 1
}
m.rulesIndex[r.id] = p
m.mutex.Unlock()
return &r, nil
@@ -158,25 +148,38 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
p, ok := m.rulesIndex[r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.rulesIndex, r.id)
var toUpdate []Rule
if r.direction == fw.RuleDirectionIN {
_, ok := m.incomingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
toUpdate = m.incomingRules
} else {
_, ok := m.outgoingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
toUpdate = m.outgoingRules
}
for i := 0; i < len(toUpdate); i++ {
m.rulesIndex[toUpdate[i].id] = i
}
return nil
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = m.outgoingRules[:0]
m.incomingRules = m.incomingRules[:0]
m.rulesIndex = make(map[string]int)
return nil
}
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool {
@@ -188,8 +191,8 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData, m.incomingRules, true)
}
// dropFilter implements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
// dropFilter imlements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
@@ -221,49 +224,37 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
log.Errorf("unknown layer: %v", d.decoded[0])
return true
}
var ip net.IP
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
ip = d.ip4.SrcIP
} else {
ip = d.ip4.DstIP
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
ip = d.ip6.SrcIP
} else {
ip = d.ip6.DstIP
}
}
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["::"], d)
if ok {
return filter
}
// default policy is DROP ALL
return true
}
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
payloadLayer := d.decoded[1]
// check if IP address match by IP
for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) {
continue
if rule.matchByIP {
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
if !d.ip4.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip4.DstIP.Equal(rule.ip) {
continue
}
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
if !d.ip6.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip6.DstIP.Equal(rule.ip) {
continue
}
}
}
}
if rule.protoLayer == layerTypeAll {
return rule.drop, true
return rule.drop
}
if payloadLayer != rule.protoLayer {
@@ -273,36 +264,38 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
switch payloadLayer {
case layers.LayerTypeTCP:
if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop, true
return rule.drop
}
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
return rule.drop, true
return rule.drop
}
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
return rule.drop, true
return rule.drop
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.udpHook(packetData), true
return rule.udpHook(packetData)
}
if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop, true
return rule.drop
}
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
return rule.drop, true
return rule.drop
}
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true
return rule.drop
}
return rule.drop, true
return rule.drop
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true
return rule.drop
}
}
return false, false
// default policy is DROP ALL
return true
}
// SetNetwork of the wireguard interface to which filtering applied
@@ -332,19 +325,19 @@ func (m *Manager) AddUDPPacketHook(
}
m.mutex.Lock()
var toUpdate []Rule
if in {
r.direction = fw.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule)
}
m.incomingRules[r.ip.String()][r.id] = r
m.incomingRules = append([]Rule{r}, m.incomingRules...)
toUpdate = m.incomingRules
} else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
}
m.outgoingRules[r.ip.String()][r.id] = r
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
toUpdate = m.outgoingRules
}
for i := range toUpdate {
m.rulesIndex[toUpdate[i].id] = i
}
m.mutex.Unlock()
return r.id
@@ -352,26 +345,15 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeleteRule(&rule)
}
for _, r := range m.incomingRules {
if r.id == hookID {
return m.DeleteRule(&r)
}
}
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeleteRule(&rule)
}
for _, r := range m.outgoingRules {
if r.id == hookID {
return m.DeleteRule(&r)
}
}
return fmt.Errorf("hook with given id not found")
}
// SetResetHook which will be executed in the end of Reset method
func (m *Manager) SetResetHook(hook func() error) {
m.resetHook = hook
}

View File

@@ -16,7 +16,6 @@ import (
type IFaceMock struct {
SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress
}
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
@@ -26,13 +25,6 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
return i.SetFilterFunc(iface)
}
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return iface.WGAddress{}
}
return i.AddressFunc()
}
func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },
@@ -71,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -106,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -119,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop
comment = "Test rule 2"
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -131,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
t.Errorf("rule2 is not in the rulesIndex")
}
err = m.DeleteRule(rule2)
@@ -141,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
t.Errorf("rule1 still in the rulesIndex")
}
}
@@ -177,29 +169,26 @@ func TestAddUDPPacketHook(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &Manager{
incomingRules: map[string]RuleSet{},
outgoingRules: map[string]RuleSet{},
incomingRules: []Rule{},
outgoingRules: []Rule{},
rulesIndex: make(map[string]int),
}
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule
if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 {
if len(manager.incomingRules) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return
}
for _, rule := range manager.incomingRules[tt.ip.String()] {
addedRule = rule
}
addedRule = manager.incomingRules[0]
} else {
if len(manager.outgoingRules) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return
}
for _, rule := range manager.outgoingRules[tt.ip.String()] {
addedRule = rule
}
addedRule = manager.outgoingRules[0]
}
if !tt.ip.Equal(addedRule.ip) {
@@ -222,6 +211,17 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected udpHook to be set")
return
}
// Ensure rulesIndex is correctly updated
index, ok := manager.rulesIndex[addedRule.id]
if !ok {
t.Errorf("expected rule to be in rulesIndex")
return
}
if index != 0 {
t.Errorf("expected rule index to be 0, got %d", index)
return
}
})
}
}
@@ -244,7 +244,7 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -256,7 +256,7 @@ func TestManagerReset(t *testing.T) {
return
}
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
t.Errorf("rules is not empty")
}
}
@@ -282,7 +282,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -346,12 +346,10 @@ func TestRemovePacketHook(t *testing.T) {
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
found = true
break
}
for _, rule := range manager.outgoingRules {
if rule.id == hookID {
found = true
break
}
}
@@ -366,11 +364,9 @@ func TestRemovePacketHook(t *testing.T) {
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
for _, rule := range manager.outgoingRules {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
@@ -398,9 +394,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
} else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")

View File

@@ -166,9 +166,10 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
EnVar::SetHKLM
EnVar::AddValueEx "path" "$INSTDIR"
SetShellVarContext all
SetShellVarContext current
CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
SetShellVarContext all
SectionEnd
Section -Post
@@ -195,9 +196,10 @@ Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
RmDir /r "$INSTDIR"
SetShellVarContext all
SetShellVarContext current
Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
SetShellVarContext all
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
@@ -207,7 +209,8 @@ SectionEnd
Function LaunchLink
SetShellVarContext all
SetShellVarContext current
SetOutPath $INSTDIR
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
SetShellVarContext all
FunctionEnd

View File

@@ -33,27 +33,14 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
manager firewall.Manager
ipsetCounter int
rulesPairs map[string][]firewall.Rule
mutex sync.Mutex
}
type ipsetInfo struct {
name string
ipCount int
}
func newDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}
manager firewall.Manager
rulesPairs map[string][]firewall.Rule
mutex sync.Mutex
}
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
//
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.mutex.Lock()
defer d.mutex.Unlock()
@@ -74,12 +61,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
return
}
defer func() {
if err := d.manager.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
rules, squashedProtocols := d.squashAcceptRules(networkMap)
enableSSH := (networkMap.PeerConfig != nil &&
@@ -127,31 +108,8 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
applyFailed := false
newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
// calculate which IP's can be grouped in by which ipset
// to do that we use rule selector (which is just rule properties without IP's)
for _, r := range rules {
selector := d.getRuleGroupingSelector(r)
ipset, ok := ipsetByRuleSelectors[selector]
if !ok {
ipset = &ipsetInfo{}
}
ipset.ipCount++
ipsetByRuleSelectors[selector] = ipset
}
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
if ipset.name == "" {
d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
}
ipsetName := ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
pairID, rulePair, err := d.protoRuleToFirewallRule(r)
if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
applyFailed = true
@@ -196,10 +154,7 @@ func (d *DefaultManager) Stop() {
}
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (string, []firewall.Rule, error) {
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP)
if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
@@ -235,9 +190,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
var err error
switch r.Direction {
case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
rules, err = d.addInRules(ip, protocol, port, action, "")
case mgmProto.FirewallRule_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
rules, err = d.addOutRules(ip, protocol, port, action, "")
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
@@ -250,17 +205,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return ruleID, rules, nil
}
func (d *DefaultManager) addInRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
@@ -270,8 +217,7 @@ func (d *DefaultManager) addInRules(
return rules, nil
}
rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
@@ -279,17 +225,9 @@ func (d *DefaultManager) addInRules(
return append(rules, rule), nil
}
func (d *DefaultManager) addOutRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
@@ -299,8 +237,7 @@ func (d *DefaultManager) addOutRules(
return rules, nil
}
rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
@@ -345,10 +282,6 @@ func (d *DefaultManager) squashAcceptRules(
in := protoMatch{}
out := protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if:
@@ -365,22 +298,12 @@ func (d *DefaultManager) squashAcceptRules(
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = map[string]int{}
}
match := protocols[r.Protocol]
// special case, when we receive this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
if _, ok := match[r.PeerIP]; ok {
return
}
ipset := protocols[r.Protocol]
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
match[r.PeerIP] = i
}
for i, r := range networkMap.FirewallRules {
@@ -393,7 +316,7 @@ func (d *DefaultManager) squashAcceptRules(
}
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
// only for ther first element ALL, it must be done first
protocolOrders := []mgmProto.FirewallRuleProtocol{
mgmProto.FirewallRule_ALL,
mgmProto.FirewallRule_ICMP,
@@ -401,6 +324,9 @@ func (d *DefaultManager) squashAcceptRules(
mgmProto.FirewallRule_UDP,
}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
@@ -456,11 +382,6 @@ func (d *DefaultManager) squashAcceptRules(
return append(rules, squashedRules...), squashedProtocols
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
switch protocol {
case mgmProto.FirewallRule_TCP:

View File

@@ -1,4 +1,4 @@
//go:build !linux || android
//go:build !linux
package acl
@@ -6,8 +6,7 @@ import (
"fmt"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
@@ -19,10 +18,10 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if err != nil {
return nil, err
}
if err := fm.AllowNetbird(); err != nil {
log.Warnf("failed to allow netbird interface traffic: %v", err)
}
return newDefaultManager(fm), nil
return &DefaultManager{
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}, nil
}
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}

View File

@@ -1,5 +1,3 @@
//go:build !android
package acl
import (
@@ -9,69 +7,30 @@ import (
"github.com/netbirdio/netbird/client/firewall/iptables"
"github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
// Create creates a firewall manager instance for the Linux
func Create(iface IFaceMapper) (*DefaultManager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
var fm firewall.Manager
var err error
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for access control")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
log.Infof("failed to create iptables manager for access control: %s", err)
}
case checkfw.NFTABLES:
log.Debug("creating an nftables firewall manager for access control")
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager for access control: %s", err)
}
}
var resetHookForUserspace func() error
if fm != nil && err == nil {
// err shadowing is used here, to ignore this error
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
resetHookForUserspace = fm.Reset
}
if iface.IsUserspaceBind() {
// use userspace packet filtering firewall
usfm, err := uspfilter.Create(iface)
if err != nil {
if fm, err = uspfilter.Create(iface); err != nil {
log.Debugf("failed to create userspace filtering firewall: %s", err)
return nil, err
}
// set kernel space firewall Reset as hook for userspace firewall
// manager Reset method, to clean up
if resetHookForUserspace != nil {
usfm.SetResetHook(resetHookForUserspace)
} else {
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager: %s", err)
// fallback to iptables
if fm, err = iptables.Create(iface); err != nil {
log.Errorf("failed to create iptables manager: %s", err)
return nil, err
}
}
// to be consistent for any future extensions.
// ignore this error
if err := usfm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
fm = usfm
}
if fm == nil || err != nil {
log.Errorf("failed to create firewall manager: %s", err)
// no firewall manager found or initialized correctly
return nil, err
}
return newDefaultManager(fm), nil
return &DefaultManager{
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}, nil
}

View File

@@ -1,13 +1,11 @@
package acl
import (
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -34,22 +32,13 @@ func TestDefaultManager(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
iface := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo")
iface.EXPECT().SetFilter(gomock.Any())
// we receive one rule from the management so for testing purposes ignore it
acl, err := Create(ifaceMock)
acl, err := Create(iface)
if err != nil {
t.Errorf("create ACL manager: %v", err)
return
@@ -189,33 +178,31 @@ func TestDefaultManagerSquashRules(t *testing.T) {
}
r := rules[0]
switch {
case r.PeerIP != "0.0.0.0":
if r.PeerIP != "0.0.0.0" {
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
case r.Direction != mgmProto.FirewallRule_IN:
} else if r.Direction != mgmProto.FirewallRule_IN {
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.FirewallRule_ALL:
} else if r.Protocol != mgmProto.FirewallRule_ALL {
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.FirewallRule_ACCEPT:
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
r = rules[1]
switch {
case r.PeerIP != "0.0.0.0":
if r.PeerIP != "0.0.0.0" {
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
case r.Direction != mgmProto.FirewallRule_OUT:
} else if r.Direction != mgmProto.FirewallRule_OUT {
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.FirewallRule_ALL:
} else if r.Protocol != mgmProto.FirewallRule_ALL {
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.FirewallRule_ACCEPT:
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
@@ -283,7 +270,7 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
manager := &DefaultManager{}
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
t.Errorf("we should got same amount of rules as intput, got %v", len(rules))
}
}
@@ -324,22 +311,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
iface := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo")
iface.EXPECT().SetFilter(gomock.Any())
// we receive one rule from the management so for testing purposes ignore it
acl, err := Create(ifaceMock)
acl, err := Create(iface)
if err != nil {
t.Errorf("create ACL manager: %v", err)
return

View File

@@ -1,203 +0,0 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/netbirdio/netbird/client/internal"
)
// HostedGrantType grant type for device flow on Hosted
const (
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
)
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
type RequestDeviceCodePayload struct {
Audience string `json:"audience"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// TokenRequestPayload used for requesting the auth0 token
type TokenRequestPayload struct {
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code,omitempty"`
ClientID string `json:"client_id"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// TokenRequestResponse used for parsing Hosted token's response
type TokenRequestResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
TokenInfo
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
return &DeviceAuthorizationFlow{
providerConfig: config,
HTTPClient: httpClient,
}, nil
}
// GetClientID returns the provider client id
func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID)
form.Add("audience", d.providerConfig.Audience)
form.Add("scope", d.providerConfig.Scope)
req, err := http.NewRequest("POST", d.providerConfig.DeviceAuthEndpoint,
strings.NewReader(form.Encode()))
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("creating request failed with error: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := d.HTTPClient.Do(req)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("doing request failed with error: %v", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("reading body failed with error: %v", err)
}
if res.StatusCode != 200 {
return AuthFlowInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
}
deviceCode := AuthFlowInfo{}
err = json.Unmarshal(body, &deviceCode)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
}
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
if deviceCode.VerificationURIComplete == "" {
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
return deviceCode, err
}
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID)
form.Add("grant_type", HostedGrantType)
form.Add("device_code", info.DeviceCode)
req, err := http.NewRequest("POST", d.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := d.HTTPClient.Do(req)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
}
defer func() {
err := res.Body.Close()
if err != nil {
return
}
}()
body, err := io.ReadAll(res.Body)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
}
if res.StatusCode > 499 {
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
return tokenResponse, nil
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
for {
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
if tokenResponse.Error != "" {
if tokenResponse.Error == "authorization_pending" {
continue
} else if tokenResponse.Error == "slow_down" {
interval += (3 * time.Second)
ticker.Reset(interval)
continue
}
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
}
tokenInfo := TokenInfo{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
IDToken: tokenResponse.IDToken,
ExpiresIn: tokenResponse.ExpiresIn,
UseIDToken: d.providerConfig.UseIDToken,
}
err = isValidAccessToken(tokenInfo.GetTokenToUse(), d.providerConfig.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
return tokenInfo, err
}
}
}

View File

@@ -1,109 +0,0 @@
package auth
import (
"context"
"fmt"
"net/http"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
)
// OAuthFlow represents an interface for authorization using different OAuth 2.0 flows
type OAuthFlow interface {
RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error)
WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error)
GetClientID(ctx context.Context) string
}
// HTTPClient http client interface for API calls
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
type AuthFlowInfo struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}
// Claims used when validating the access token
type Claims struct {
Audience interface{} `json:"aud"`
}
// TokenInfo holds information of issued access token
type TokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
UseIDToken bool `json:"-"`
}
// GetTokenToUse returns either the access or id token based on UseIDToken field
func (t TokenInfo) GetTokenToUse() string {
if t.UseIDToken {
return t.IDToken
}
return t.AccessToken
}
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
//
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil {
// fallback to device code flow
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
log.Debug("falling back to device code flow")
return authenticateWithDeviceCodeFlow(ctx, config)
}
return pkceFlow, nil
}
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
return nil, fmt.Errorf("no SSO provider returned from management. " +
"Please proceed with setting up this device using setup keys " +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
case ok && s.Code() == codes.Unimplemented:
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL)
default:
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
}
}
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}

View File

@@ -1,255 +0,0 @@
package auth
import (
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"html/template"
"net"
"net/http"
"net/url"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
)
var _ OAuthFlow = &PKCEAuthorizationFlow{}
const (
queryState = "state"
queryCode = "code"
queryError = "error"
queryErrorDesc = "error_description"
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig internal.PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
// find the first available redirect URL
for _, redirectURL := range config.RedirectURLs {
if !isRedirectURLPortUsed(redirectURL) {
availableRedirectURL = redirectURL
break
}
}
if availableRedirectURL == "" {
return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
}
cfg := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthorizationEndpoint,
TokenURL: config.TokenEndpoint,
},
RedirectURL: availableRedirectURL,
Scopes: strings.Split(config.Scope, " "),
}
return &PKCEAuthorizationFlow{
providerConfig: config,
oAuthConfig: cfg,
}, nil
}
// GetClientID returns the provider client id
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
return p.providerConfig.ClientID
}
// RequestAuthInfo requests a authorization code login flow information.
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
state, err := randomBytesInHex(24)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
}
p.state = state
codeVerifier, err := randomBytesInHex(64)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
}
p.codeVerifier = codeVerifier
codeChallenge := createCodeChallenge(codeVerifier)
authURL := p.oAuthConfig.AuthCodeURL(
state,
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
)
return AuthFlowInfo{
VerificationURIComplete: authURL,
ExpiresIn: defaultPKCETimeoutSeconds,
}, nil
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
}
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
log.Errorf("failed to close the server: %v", err)
}
}()
go p.startServer(server, tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
case err := <-errChan:
return TokenInfo{}, err
}
}
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
token, err := p.handleRequest(req)
if err != nil {
renderPKCEFlowTmpl(w, err)
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
return
}
renderPKCEFlowTmpl(w, nil)
tokenChan <- token
})
server.Handler = mux
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- err
}
}
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on the state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
tokenInfo := TokenInfo{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
TokenType: token.TokenType,
ExpiresIn: token.Expiry.Second(),
UseIDToken: p.providerConfig.UseIDToken,
}
if idToken, ok := token.Extra("id_token").(string); ok {
tokenInfo.IDToken = idToken
}
// if a provider doesn't support an audience, use the Client ID for token verification
audience := p.providerConfig.Audience
if audience == "" {
audience = p.providerConfig.ClientID
}
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
return tokenInfo, nil
}
func createCodeChallenge(codeVerifier string) string {
sha2 := sha256.Sum256([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(sha2[:])
}
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
func isRedirectURLPortUsed(redirectURL string) bool {
parsedURL, err := url.Parse(redirectURL)
if err != nil {
log.Errorf("failed to parse redirect URL: %v", err)
return true
}
addr := fmt.Sprintf(":%s", parsedURL.Port())
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
if err != nil {
return false
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf("error while closing the connection: %v", err)
}
}()
return true
}
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
data := make(map[string]string)
if authError != nil {
data["Error"] = authError.Error()
}
if err := tmpl.Execute(w, data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

View File

@@ -1,60 +0,0 @@
package auth
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"strings"
)
func randomBytesInHex(count int) (string, error) {
buf := make([]byte, count)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("could not generate %d random bytes: %v", count, err)
}
return hex.EncodeToString(buf), nil
}
// isValidAccessToken is a simple validation of the access token
func isValidAccessToken(token string, audience string) error {
if token == "" {
return fmt.Errorf("token received is empty")
}
encodedClaims := strings.Split(token, ".")[1]
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
if err != nil {
return err
}
claims := Claims{}
err = json.Unmarshal(claimsString, &claims)
if err != nil {
return err
}
if claims.Audience == nil {
return fmt.Errorf("required token field audience is absent")
}
// Audience claim of JWT can be a string or an array of strings
switch aud := claims.Audience.(type) {
case string:
if aud == audience {
return nil
}
case []interface{}:
for _, audItem := range aud {
if audStr, ok := audItem.(string); ok && audStr == audience {
return nil
}
}
}
return fmt.Errorf("invalid JWT token audience field")
}

View File

@@ -1,3 +0,0 @@
//go:build !linux || android
package checkfw

View File

@@ -1,56 +0,0 @@
//go:build !android
package checkfw
import (
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
IPTABLESWITHV6
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func Check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil {
if isIptablesClientAvailable(ip) {
ipSupport := IPTABLES
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if ip6Err == nil {
if isIptablesClientAvailable(ipv6) {
ipSupport = IPTABLESWITHV6
}
}
return ipSupport
}
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -215,12 +215,10 @@ func update(input ConfigInput) (*Config, error) {
}
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
if *input.PreSharedKey != "" {
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
*input.PreSharedKey, config.PreSharedKey)
config.PreSharedKey = *input.PreSharedKey
refresh = true
}
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
*input.PreSharedKey, config.PreSharedKey)
config.PreSharedKey = *input.PreSharedKey
refresh = true
}
if config.SSHKey == "" {
@@ -273,9 +271,9 @@ func parseURL(serviceName, serviceURL string) (*url.URL, error) {
if parsedMgmtURL.Port() == "" {
switch parsedMgmtURL.Scheme {
case "https":
parsedMgmtURL.Host += ":443"
parsedMgmtURL.Host = parsedMgmtURL.Host + ":443"
case "http":
parsedMgmtURL.Host += ":80"
parsedMgmtURL.Host = parsedMgmtURL.Host + ":80"
default:
log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL)
}

View File

@@ -23,6 +23,9 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
if err != nil {
return
}
managementURL := "https://test.management.url:33071"
adminURL := "https://app.admin.url:443"
path := filepath.Join(t.TempDir(), "config.json")
@@ -60,22 +63,7 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey)
// case 4: new empty pre-shared key config -> fetch it
newPreSharedKey := ""
config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: path,
PreSharedKey: &newPreSharedKey,
})
if err != nil {
return
}
assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey)
// case 5: existing config, but new managementURL has been provided -> update config
// case 4: existing config, but new managementURL has been provided -> update config
newManagementURL := "https://test.newManagement.url:33071"
config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: newManagementURL,

View File

@@ -12,9 +12,8 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
@@ -22,40 +21,10 @@ import (
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/version"
)
// RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{})
}
// RunClientMobile with main logic on mobile system
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency)
}
func RunClientiOS(ctx context.Context, config *Config, statusRecorder *peer.Status, fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, interfaceName string) error {
mobileDependency := MobileDependency{
FileDescriptor: fileDescriptor,
InterfaceName: interfaceName,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
return runClient(ctx, config, statusRecorder, mobileDependency)
}
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error {
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@@ -109,7 +78,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
cancel()
}()
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
@@ -182,7 +151,14 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return wrapErr(err)
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
// in case of non Android os these variables will be nil
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
err = engine.Start()
if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
@@ -192,6 +168,8 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
statusRecorder.ClientStart()
<-engineCtx.Done()
statusRecorder.ClientTeardown()
@@ -212,7 +190,6 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return nil
}
statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff)
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)

View File

@@ -16,11 +16,11 @@ import (
// DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct {
Provider string
ProviderConfig DeviceAuthProviderConfig
ProviderConfig ProviderConfig
}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ProviderConfig has all attributes needed to initiate a device authorization flow
type ProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
@@ -88,7 +88,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: DeviceAuthProviderConfig{
ProviderConfig: ProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
@@ -105,7 +105,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
}
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
@@ -113,7 +113,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
return deviceAuthorizationFlow, nil
}
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
func isProviderConfigValid(config ProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")

View File

@@ -3,26 +3,29 @@
package dns
import (
"bufio"
"bytes"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
)
const (
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n"
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
fileMaxLineCharsLimit = 256
fileMaxNumberOfSearchDomains = 6
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfSearchBeginContent = "search "
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
fileGeneratedResolvConfSearchBeginContent + "%s\n"
)
const (
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
fileMaxLineCharsLimit = 256
fileMaxNumberOfSearchDomains = 6
)
var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent)
type fileConfigurator struct {
originalPerms os.FileMode
}
@@ -51,39 +54,53 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
}
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
}
if !backupFileExist {
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file")
managerType, err := getOSDNSManagerType()
if err != nil {
return err
}
switch managerType {
case fileManager, netbirdManager:
if !backupFileExist {
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file")
}
}
default:
// todo improve this and maybe restart DNS manager from scratch
return fmt.Errorf("something happened and file manager is not your prefered host dns configurator, restart the agent")
}
searchDomainList := searchDomains(config)
var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly || dConf.disabled {
continue
}
if appendedDomains >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
continue
}
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
continue
}
originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
if err != nil {
log.Error(err)
searchDomains += " " + dConf.domain
appendedDomains++
}
searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)
buf := prepareResolvConfContent(
searchDomainList,
append([]string{config.serverIP}, nameServers...),
others)
log.Debugf("creating managed file %s", defaultResolvConfPath)
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
if err != nil {
restoreErr := f.restore()
if restoreErr != nil {
err = f.restore()
if err != nil {
log.Errorf("attempt to restore default file failed with error: %s", err)
}
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err)
return err
}
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, appendedDomains, searchDomains)
return nil
}
@@ -115,138 +132,15 @@ func (f *fileConfigurator) restore() error {
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
func writeDNSConfig(content, fileName string, permissions os.FileMode) error {
log.Debugf("creating managed file %s", fileName)
var buf bytes.Buffer
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
for _, cfgLine := range others {
buf.WriteString(cfgLine)
buf.WriteString("\n")
}
if len(searchDomains) > 0 {
buf.WriteString("search ")
buf.WriteString(strings.Join(searchDomains, " "))
buf.WriteString("\n")
}
for _, ns := range nameServers {
buf.WriteString("nameserver ")
buf.WriteString(ns)
buf.WriteString("\n")
}
return buf
}
func searchDomains(config hostDNSConfig) []string {
listOfDomains := make([]string, 0)
for _, dConf := range config.domains {
if dConf.matchOnly || dConf.disabled {
continue
}
listOfDomains = append(listOfDomains, dConf.domain)
}
return listOfDomains
}
func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
file, err := os.Open(resolvconfFile)
buf.WriteString(content)
err := os.WriteFile(fileName, buf.Bytes(), permissions)
if err != nil {
err = fmt.Errorf(`could not read existing resolv.conf`)
return
return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err)
}
defer file.Close()
reader := bufio.NewReader(file)
for {
lineBytes, isPrefix, readErr := reader.ReadLine()
if readErr != nil {
break
}
if isPrefix {
err = fmt.Errorf(`resolv.conf line too long`)
return
}
line := strings.TrimSpace(string(lineBytes))
if strings.HasPrefix(line, "#") {
continue
}
if strings.HasPrefix(line, "domain") {
continue
}
if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
line = strings.ReplaceAll(line, "rotate", "")
splitLines := strings.Fields(line)
if len(splitLines) == 1 {
continue
}
line = strings.Join(splitLines, " ")
}
if strings.HasPrefix(line, "search") {
splitLines := strings.Fields(line)
if len(splitLines) < 2 {
continue
}
searchDomains = splitLines[1:]
continue
}
if strings.HasPrefix(line, "nameserver") {
splitLines := strings.Fields(line)
if len(splitLines) != 2 {
continue
}
nameServers = append(nameServers, splitLines[1])
continue
}
others = append(others, line)
}
return
}
// merge search domains lists and cut off the list if it is too long
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
lineSize := len("search")
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))
lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains)
_ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains)
return searchDomainsList
}
// validateAndFillSearchDomains checks if the search domains list is not too long and if the line is not too long
// extend s slice with vs elements
// return with the number of characters in the searchDomains line
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
for _, sd := range vs {
tmpCharsNumber := initialLineChars + 1 + len(sd)
if tmpCharsNumber > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
continue
}
initialLineChars = tmpCharsNumber
if len(*s) >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
continue
}
*s = append(*s, sd)
}
return initialLineChars
return nil
}
func copyFile(src, dest string) error {

View File

@@ -1,62 +0,0 @@
package dns
import (
"fmt"
"testing"
)
func Test_mergeSearchDomains(t *testing.T) {
searchDomains := []string{"a", "b"}
originDomains := []string{"a", "b"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 4 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
}
}
func Test_mergeSearchTooMuchDomains(t *testing.T) {
searchDomains := []string{"a", "b", "c", "d", "e", "f", "g"}
originDomains := []string{"h", "i"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 6 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
}
}
func Test_mergeSearchTooMuchDomainsInOrigin(t *testing.T) {
searchDomains := []string{"a", "b"}
originDomains := []string{"c", "d", "e", "f", "g"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 6 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
}
}
func Test_mergeSearchTooLongDomain(t *testing.T) {
searchDomains := []string{getLongLine()}
originDomains := []string{"b"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 1 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
}
searchDomains = []string{"b"}
originDomains = []string{getLongLine()}
mergedDomains = mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 1 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
}
}
func getLongLine() string {
x := "search "
for {
for i := 0; i <= 9; i++ {
if len(x) > fileMaxLineCharsLimit {
return x
}
x = fmt.Sprintf("%s%d", x, i)
}
}
}

View File

@@ -61,7 +61,7 @@ func newNoopHostMocker() hostManager {
}
}
func dnsConfigTohostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig {
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig {
config := hostDNSConfig{
routeAll: false,
serverIP: ip,
@@ -78,7 +78,7 @@ func dnsConfigTohostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostD
for _, domain := range nsConfig.Domains {
config.domains = append(config.domains, domainConfig{
domain: strings.TrimSuffix(domain, "."),
matchOnly: !nsConfig.SearchDomainsEnabled,
matchOnly: true,
})
}
}

View File

@@ -1,9 +1,13 @@
package dns
import (
"github.com/netbirdio/netbird/iface"
)
type androidHostManager struct {
}
func newHostManager() (hostManager, error) {
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
return &androidHostManager{}, nil
}

View File

@@ -1,5 +1,3 @@
//go:build !ios
package dns
import (
@@ -11,6 +9,8 @@ import (
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
)
const (
@@ -34,7 +34,7 @@ type systemConfigurator struct {
createdKeys map[string]struct{}
}
func newHostManager() (hostManager, error) {
func newHostManager(_ *iface.WGIface) (hostManager, error) {
return &systemConfigurator{
createdKeys: make(map[string]struct{}),
}, nil
@@ -184,11 +184,12 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
}
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey, existingNameserver := s.getPrimaryService()
primaryServiceKey := s.getPrimaryService()
if primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key")
}
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port)
if err != nil {
return err
}
@@ -197,32 +198,27 @@ func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error
return nil
}
func (s *systemConfigurator) getPrimaryService() (string, string) {
func (s *systemConfigurator) getPrimaryService() string {
line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands)
if err != nil {
log.Error("got error while sending the command: ", err)
return "", ""
return ""
}
scanner := bufio.NewScanner(bytes.NewReader(b))
primaryService := ""
router := ""
for scanner.Scan() {
text := scanner.Text()
if strings.Contains(text, "PrimaryService") {
primaryService = strings.TrimSpace(strings.Split(text, ":")[1])
}
if strings.Contains(text, "Router") {
router = strings.TrimSpace(strings.Split(text, ":")[1])
return strings.TrimSpace(strings.Split(text, ":")[1])
}
}
return primaryService, router
return ""
}
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand)

View File

@@ -1,45 +0,0 @@
package dns
import (
"strconv"
"strings"
)
type iosHostManager struct {
dnsManager IosDnsManager
config hostDNSConfig
}
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
return &iosHostManager{
dnsManager: dnsManager,
}, nil
}
func (a iosHostManager) applyDNSConfig(config hostDNSConfig) error {
var configAsString []string
configAsString = append(configAsString, config.serverIP)
configAsString = append(configAsString, strconv.Itoa(config.serverPort))
configAsString = append(configAsString, strconv.FormatBool(config.routeAll))
var domainConfigAsString []string
for _, domain := range config.domains {
var domainAsString []string
domainAsString = append(domainAsString, strconv.FormatBool(domain.disabled))
domainAsString = append(domainAsString, domain.domain)
domainAsString = append(domainAsString, strconv.FormatBool(domain.matchOnly))
domainConfigAsString = append(domainConfigAsString, strings.Join(domainAsString, "|"))
}
domainConfig := strings.Join(domainConfigAsString, ";")
configAsString = append(configAsString, domainConfig)
outputString := strings.Join(configAsString, ",")
a.dnsManager.ApplyDns(outputString)
return nil
}
func (a iosHostManager) restoreHostDNS() error {
return nil
}
func (a iosHostManager) supportCustomPort() bool {
return false
}

View File

@@ -5,10 +5,10 @@ package dns
import (
"bufio"
"fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os"
"strings"
log "github.com/sirupsen/logrus"
)
const (
@@ -25,7 +25,7 @@ const (
type osManagerType int
func newHostManager(wgInterface WGIface) (hostManager, error) {
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, err

View File

@@ -6,6 +6,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/iface"
)
const (
@@ -22,14 +24,16 @@ const (
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"
)
type registryConfigurator struct {
guid string
routingAll bool
guid string
routingAll bool
existingSearchDomains []string
}
func newHostManager(wgInterface WGIface) (hostManager, error) {
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil {
return nil, err
@@ -146,11 +150,30 @@ func (r *registryConfigurator) restoreHostDNS() error {
log.Error(err)
}
return r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey)
return r.updateSearchDomains([]string{})
}
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey)
if err != nil {
return fmt.Errorf("unable to get current search domains failed with error: %s", err)
}
valueList := strings.Split(value, ",")
setExisting := false
if len(r.existingSearchDomains) == 0 {
r.existingSearchDomains = valueList
setExisting = true
}
if len(domains) == 0 && setExisting {
log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains)
return nil
}
newList := append(r.existingSearchDomains, domains...)
err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ","))
if err != nil {
return fmt.Errorf("adding search domain failed with error: %s", err)
}
@@ -214,3 +237,33 @@ func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
}
return nil
}
func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) {
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
if err != nil {
return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
}
defer regKey.Close()
val, _, err := regKey.GetStringValue(key)
if err != nil {
return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err)
}
return val, nil
}
func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error {
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
}
defer regKey.Close()
err = regKey.SetStringValue(key, value)
if err != nil {
return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err)
}
return nil
}

View File

@@ -2,7 +2,6 @@ package dns
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
)
@@ -14,7 +13,7 @@ type MockServer struct {
}
// Initialize mock implementation of Initialize from Server interface
func (m *MockServer) Initialize(manager IosDnsManager) error {
func (m *MockServer) Initialize() error {
if m.InitializeFunc != nil {
return m.InitializeFunc()
}
@@ -32,11 +31,6 @@ func (m *MockServer) DnsIP() string {
return ""
}
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
// TODO implement me
panic("implement me")
}
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
if m.UpdateDNSServerFunc != nil {
@@ -44,7 +38,3 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
}
return fmt.Errorf("method UpdateDNSServer is not implemented")
}
func (m *MockServer) SearchDomains() []string {
return make([]string, 0)
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/binary"
"fmt"
"net/netip"
"regexp"
"time"
"github.com/godbus/dbus/v5"
@@ -14,7 +15,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbversion "github.com/netbirdio/netbird/version"
"github.com/netbirdio/netbird/iface"
)
const (
@@ -71,7 +72,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
}
}
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
return nil, err
@@ -123,7 +124,7 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
searchDomains = append(searchDomains, dns.Fqdn(dConf.domain))
}
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
newDomainList := append(searchDomains, matchDomains...)
priority := networkManagerDbusSearchDomainOnlyPriority
switch {
@@ -290,7 +291,12 @@ func isNetworkManagerSupportedVersion() bool {
}
func parseVersion(inputVersion string) (*version.Version, error) {
if inputVersion == "" || !nbversion.SemverRegexp.MatchString(inputVersion) {
reg, err := regexp.Compile(version.SemverRegexpRaw)
if err != nil {
return nil, err
}
if inputVersion == "" || !reg.MatchString(inputVersion) {
return nil, fmt.Errorf("couldn't parse the provided version: Not SemVer")
}

View File

@@ -1,57 +0,0 @@
package dns
import (
"reflect"
"sort"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
)
type notifier struct {
listener listener.NetworkChangeListener
listenerMux sync.Mutex
searchDomains []string
}
func newNotifier(initialSearchDomains []string) *notifier {
sort.Strings(initialSearchDomains)
return &notifier{
searchDomains: initialSearchDomains,
}
}
func (n *notifier) setListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *notifier) onNewSearchDomains(searchDomains []string) {
sort.Strings(searchDomains)
if len(n.searchDomains) != len(searchDomains) {
n.searchDomains = searchDomains
n.notify()
return
}
if reflect.DeepEqual(n.searchDomains, searchDomains) {
return
}
n.searchDomains = searchDomains
n.notify()
}
func (n *notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged("")
}(n.listener)
}

View File

@@ -3,35 +3,24 @@
package dns
import (
"bytes"
"fmt"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
)
const resolvconfCommand = "resolvconf"
type resolvconf struct {
ifaceName string
originalSearchDomains []string
originalNameServers []string
othersConfigs []string
}
// supported "openresolv" only
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
originalSearchDomains, nameServers, others, err := originalDNSConfigs("/etc/resolv.conf")
if err != nil {
log.Error(err)
}
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
return &resolvconf{
ifaceName: wgInterface.Name(),
originalSearchDomains: originalSearchDomains,
originalNameServers: nameServers,
othersConfigs: others,
ifaceName: wgInterface.Name(),
}, nil
}
@@ -49,20 +38,37 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
}
searchDomainList := searchDomains(config)
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly || dConf.disabled {
continue
}
buf := prepareResolvConfContent(
searchDomainList,
append([]string{config.serverIP}, r.originalNameServers...),
r.othersConfigs)
if appendedDomains >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
continue
}
err = r.applyConfig(buf)
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
continue
}
searchDomains += " " + dConf.domain
appendedDomains++
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
err = r.applyConfig(content)
if err != nil {
return err
}
log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList)
log.Infof("added %d search domains. Search list: %s", appendedDomains, searchDomains)
return nil
}
@@ -75,12 +81,12 @@ func (r *resolvconf) restoreHostDNS() error {
return nil
}
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
func (r *resolvconf) applyConfig(content string) error {
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
cmd.Stdin = &content
cmd.Stdin = strings.NewReader(content)
_, err := cmd.Output()
if err != nil {
return fmt.Errorf("got an error while applying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
return fmt.Errorf("got an error while appying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
}
return nil
}

View File

@@ -3,35 +3,36 @@ package dns
import (
"context"
"fmt"
"math/big"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
)
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
type ReadyListener interface {
OnReady()
}
// IosDnsManager is a dns manager interface for iosß
type IosDnsManager interface {
ApplyDns(string)
}
const (
defaultPort = 53
customPort = 5053
defaultIP = "127.0.0.1"
customIP = "127.0.0.153"
)
// Server is a dns server interface
type Server interface {
Initialize(manager IosDnsManager) error
Initialize() error
Stop()
DnsIP() string
UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string
}
type registeredHandlerMap map[string]handlerWithStop
@@ -41,25 +42,21 @@ type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
mux sync.Mutex
service service
udpFilterHookID string
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registeredHandlerMap
localResolver *localResolver
wgInterface WGIface
wgInterface *iface.WGIface
hostManager hostManager
updateSerial uint64
listenerIsRunning bool
runtimePort int
runtimeIP string
previousConfigHash uint64
currentConfig hostDNSConfig
// permanent related properties
permanent bool
hostsDnsList []string
hostsDnsListLock sync.Mutex
interfaceName string
wgAddr string
// make sense on mobile only
searchDomainNotifier *notifier
customAddress *netip.AddrPort
enabled bool
}
type handlerWithStop interface {
@@ -73,7 +70,9 @@ type muxUpdate struct {
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, interfaceName string, wgAddr string) (*DefaultServer, error) {
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) {
mux := dns.NewServeMux()
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -83,49 +82,37 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
addrPort = &parsedAddrPort
}
var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = newServiceViaMemory(wgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService, interfaceName, wgAddr), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), "", "")
ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone()
ds.currentConfig = dnsConfigTohostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
ds.searchDomainNotifier.setListener(listener)
setServerDns(ds)
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, interfaceName string, wgAddr string) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
ctxCancel: stop,
service: dnsService,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
dnsMux: mux,
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
wgInterface: wgInterface,
interfaceName: interfaceName,
wgAddr: wgAddr,
customAddress: addrPort,
}
return defaultServer
if initialDnsCfg != nil {
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
}
if wgInterface.IsUserspaceBind() {
defaultServer.evelRuntimeAddressForUserspace()
}
return defaultServer, nil
}
// Initialize instantiate host manager and the dns service
// Initialize instantiate host manager. It required to be initialized wginterface
func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -134,8 +121,33 @@ func (s *DefaultServer) Initialize() (err error) {
return nil
}
s.hostManager, err = s.initialize()
return err
if !s.wgInterface.IsUserspaceBind() {
s.evalRuntimeAddress()
}
s.hostManager, err = newHostManager(s.wgInterface)
return
}
// listen runs the listener in a go routine
func (s *DefaultServer) listen() {
// nil check required in unit tests
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.udpFilterHookID = s.filterDNSTraffic()
s.setListenerStatus(true)
return
}
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
}
// DnsIP returns the DNS resolver server IP address
@@ -143,7 +155,38 @@ func (s *DefaultServer) Initialize() (err error) {
// When kernel space interface used it return real DNS server listener IP address
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
func (s *DefaultServer) DnsIP() string {
return s.service.RuntimeIP()
if !s.enabled {
return ""
}
return s.runtimeIP
}
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" && s.wgInterface != nil {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *DefaultServer) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
// Stop stops the server
@@ -159,23 +202,37 @@ func (s *DefaultServer) Stop() {
}
}
s.service.Stop()
err := s.stopListener()
if err != nil {
log.Error(err)
}
}
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDnsListLock.Lock()
defer s.hostsDnsListLock.Unlock()
s.hostsDnsList = hostsDnsList
_, ok := s.dnsMuxMap[nbdns.RootZone]
if ok {
log.Debugf("on new host DNS config but skip to apply it")
return
func (s *DefaultServer) stopListener() error {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
// udpFilterHookID here empty only in the unit tests
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}
s.udpFilterHookID = ""
s.listenerIsRunning = false
return nil
}
log.Debugf("update host DNS settings: %+v", hostsDnsList)
s.addHostRootZone()
if !s.listenerIsRunning {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
}
return nil
}
// UpdateDNSServer processes an update received from the management service
@@ -223,28 +280,15 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
}
}
func (s *DefaultServer) SearchDomains() []string {
var searchDomains []string
for _, dConf := range s.currentConfig.domains {
if dConf.disabled {
continue
}
if dConf.matchOnly {
continue
}
searchDomains = append(searchDomains, dConf.domain)
}
return searchDomains
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records
if update.ServiceEnable {
_ = s.service.Listen()
} else if !s.permanent {
s.service.Stop()
if !update.ServiceEnable {
if err := s.stopListener(); err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.listen()
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
@@ -255,16 +299,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
s.currentConfig = dnsConfigTohostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
hostUpdate := s.currentConfig
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
hostUpdate.routeAll = false
}
@@ -272,10 +317,6 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err)
}
if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
}
return nil
}
@@ -315,10 +356,10 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue
}
handler := newUpstreamResolver(s.ctx, s.interfaceName, s.wgAddr)
handler := newUpstreamResolver(s.ctx)
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue
}
@@ -336,7 +377,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
// reapply DNS settings, but it not touch the original configuration and serial number
// because it is temporal deactivation until next try
//
// after some period defined by upstream it tries to reactivate self by calling this hook
// after some period defined by upstream it trys to reactivate self by calling this hook
// everything we need here is just to re-apply current configuration because it already
// contains this upstream settings (temporal deactivation not removed it)
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
@@ -371,32 +412,19 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registeredHandlerMap)
var isContainRootUpdate bool
for _, update := range muxUpdates {
s.service.RegisterMux(update.domain, update.handler)
s.registerMux(update.domain, update.handler)
muxUpdateMap[update.domain] = update.handler
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop()
}
if update.domain == nbdns.RootZone {
isContainRootUpdate = true
}
}
for key, existingHandler := range s.dnsMuxMap {
_, found := muxUpdateMap[key]
if !found {
if !isContainRootUpdate && key == nbdns.RootZone {
s.hostsDnsListLock.Lock()
s.addHostRootZone()
s.hostsDnsListLock.Unlock()
existingHandler.stop()
} else {
existingHandler.stop()
s.service.DeregisterMux(key)
}
existingHandler.stop()
s.deregisterMux(key)
}
}
@@ -427,6 +455,14 @@ func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
}
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate.
@@ -454,7 +490,7 @@ func (s *DefaultServer) upstreamCallbacks(
for i, item := range s.currentConfig.domains {
if _, found := removeIndex[item.domain]; found {
s.currentConfig.domains[i].disabled = true
s.service.DeregisterMux(item.domain)
s.deregisterMux(item.domain)
removeIndex[item.domain] = i
}
}
@@ -471,7 +507,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue
}
s.currentConfig.domains[i].disabled = false
s.service.RegisterMux(domain, handler)
s.registerMux(domain, handler)
}
l := log.WithField("nameservers", nsGroup.NameServers)
@@ -487,24 +523,93 @@ func (s *DefaultServer) upstreamCallbacks(
return
}
func (s *DefaultServer) addHostRootZone() {
handler := newUpstreamResolver(s.ctx, s.interfaceName, s.wgAddr)
handler.upstreamServers = make([]string, len(s.hostsDnsList))
for n, ua := range s.hostsDnsList {
a, err := netip.ParseAddr(ua)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
continue
}
ipString := ua
if !a.Is4() {
ipString = fmt.Sprintf("[%s]", ua)
}
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
func (s *DefaultServer) filterDNSTraffic() string {
filter := s.wgInterface.GetFilter()
if filter == nil {
log.Error("can't set DNS filter, filter not initialized")
return ""
}
handler.deactivate = func() {}
handler.reactivate = func() {}
s.service.RegisterMux(nbdns.RootZone, handler)
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
}
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = defaultPort
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}
func (s *DefaultServer) evalRuntimeAddress() {
defer func() {
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}()
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
return
}
ip, port, err := s.getFirstListenerAvailable()
if err != nil {
log.Error(err)
return
}
s.runtimeIP = ip
s.runtimePort = port
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}
func hasValidDnsServer(cfg *nbdns.Config) bool {
for _, c := range cfg.NameServerGroups {
if c.Primary {
return true
}
}
return false
}

View File

@@ -1,10 +0,0 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
err = s.service.Listen()
if err != nil {
return err
}
return newHostManager()
}

View File

@@ -1,5 +0,0 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager()
}

View File

@@ -1,29 +0,0 @@
package dns
import (
"fmt"
"sync"
)
var (
mutex sync.Mutex
server Server
)
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
func GetServerDns() (Server, error) {
mutex.Lock()
if server == nil {
mutex.Unlock()
return nil, fmt.Errorf("DNS server not instantiated yet")
}
s := server
mutex.Unlock()
return s, nil
}
func setServerDns(newServerServer Server) {
mutex.Lock()
server = newServerServer
defer mutex.Unlock()
}

View File

@@ -1,24 +0,0 @@
package dns
import (
"testing"
)
func TestGetServerDns(t *testing.T) {
_, err := GetServerDns()
if err == nil {
t.Errorf("invalid dns server instance")
}
srv := &MockServer{}
setServerDns(srv)
srvB, err := GetServerDns()
if err != nil {
t.Errorf("invalid dns server instance: %s", err)
}
if srvB != srv {
t.Errorf("mismatch dns instances")
}
}

View File

@@ -1,6 +0,0 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
// todo add ioDnsManager to constuctor
return newHostManager(m.ioDnsManager)
}

View File

@@ -1,5 +0,0 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@@ -11,53 +11,14 @@ import (
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
pfmock "github.com/netbirdio/netbird/iface/mocks"
)
type mocWGIface struct {
filter iface.PacketFilter
}
func (w *mocWGIface) Name() string {
panic("implement me")
}
func (w *mocWGIface) Address() iface.WGAddress {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{
IP: ip,
Network: network,
}
}
func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter
}
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
panic("implement me")
}
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
panic("implement me")
}
func (w *mocWGIface) IsUserspaceBind() bool {
return false
}
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
w.filter = filter
return nil
}
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
@@ -68,11 +29,6 @@ var zoneRecords = []nbdns.SimpleRecord{
},
}
func init() {
log.SetLevel(log.TraceLevel)
formatter.SetTextFormatter(log.StandardLogger())
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
@@ -268,11 +224,11 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", "", "")
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
if err != nil {
t.Fatal(err)
}
err = dnsServer.Initialize(nil)
err = dnsServer.Initialize()
if err != nil {
t.Fatal(err)
}
@@ -286,6 +242,8 @@ func TestUpdateDNSServer(t *testing.T) {
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
@@ -322,9 +280,9 @@ func TestUpdateDNSServer(t *testing.T) {
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Errorf("create stdnet: %v", err)
@@ -339,7 +297,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
t.Errorf("crate and init wireguard interface: %v", err)
return
}
defer func() {
@@ -358,23 +316,23 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", "", "")
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize(nil)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
@@ -463,23 +421,21 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, "", "")
if err != nil {
t.Fatalf("%v", err)
}
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
dnsServer.hostManager = newNoopHostMocker()
err = dnsServer.service.Listen()
if err != nil {
t.Fatalf("dns server is not running: %s", err)
}
dnsServer.listen()
time.Sleep(100 * time.Millisecond)
if !dnsServer.listenerIsRunning {
t.Fatal("dns server listener is not running")
}
defer dnsServer.Stop()
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{
PreferGo: true,
@@ -487,7 +443,7 @@ func TestDNSServerStartStop(t *testing.T) {
d := net.Dialer{
Timeout: time.Second * 5,
}
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
t.Log(err)
@@ -522,7 +478,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
service: newServiceViaMemory(&mocWGIface{}),
dnsMux: dns.DefaultServeMux,
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
@@ -585,239 +541,62 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
}
}
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
mux := dns.NewServeMux()
var dnsList []string
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil)
err = dnsServer.Initialize(nil)
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
var parsedAddrPort *netip.AddrPort
if addrPort != "" {
parsed, err := netip.ParseAddrPort(addrPort)
if err != nil {
t.Fatal(err)
}
parsedAddrPort = &parsed
}
defer dnsServer.Stop()
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
dnsServer := &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
}
ctx, cancel := context.WithCancel(context.TODO())
ds := &DefaultServer{
ctx: ctx,
ctxCancel: cancel,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
customAddress: parsedAddrPort,
}
ds.evalRuntimeAddress()
return ds
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
err = dnsServer.Initialize(nil)
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Enabled: true,
Primary: true,
},
},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
err = dnsServer.UpdateDNSServer(1, update)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
update2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{},
}
err = dnsServer.UpdateDNSServer(2, update2)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
}
func TestDNSPermanent_matchOnly(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
err = dnsServer.Initialize(nil)
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Domains: []string{"customdomain.com"},
Primary: false,
},
},
}
err = dnsServer.UpdateDNSServer(1, update)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
t.Helper()
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, err
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err
}
err = wgIface.Create()
if err != nil {
t.Fatalf("create and init wireguard interface: %v", err)
return nil, err
}
pf, err := uspfilter.Create(wgIface)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err
}
err = wgIface.SetFilter(pf)
if err != nil {
t.Fatalf("set packet filter: %v", err)
return nil, err
}
return wgIface, nil
}
func newDnsResolver(ip string, port int) *net.Resolver {
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 3,
}
addr := fmt.Sprintf("%s:%d", ip, port)
return d.DialContext(ctx, network, addr)
},
lastIP := getLastIPFromNetwork(ipnet, 1)
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -1,5 +0,0 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@@ -1,18 +0,0 @@
package dns
import (
"github.com/miekg/dns"
)
const (
defaultPort = 53
)
type service interface {
Listen() error
Stop()
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int
RuntimeIP() string
}

View File

@@ -1,193 +0,0 @@
package dns
import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
const (
customPort = 5053
defaultIP = "127.0.0.1"
customIP = "127.0.0.153"
)
type serviceViaListener struct {
wgInterface WGIface
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
listenIP string
listenPort int
listenerIsRunning bool
listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
}
return s
}
func (s *serviceViaListener) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.listenIP, s.listenPort, err = s.evalListenAddress()
if err != nil {
log.Errorf("failed to eval runtime address: %s", err)
return err
}
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
if s.shouldApplyPortFwd() {
s.ebpfService = ebpf.GetEbpfManagerInstance()
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
if err != nil {
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
s.ebpfService = nil
}
}
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
}
}()
return nil
}
func (s *serviceViaListener) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
if s.ebpfService != nil {
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
}
}
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaListener) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaListener) RuntimePort() int {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil {
return defaultPort
} else {
return s.listenPort
}
}
func (s *serviceViaListener) RuntimeIP() string {
return s.listenIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *serviceViaListener) evalListenAddress() (string, int, error) {
if s.customAddr != nil {
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
}
return s.getFirstListenerAvailable()
}
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
// resolution won't work.
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
// traffic on port 53 and forward it to a local DNS server running on 5053.
func (s *serviceViaListener) shouldApplyPortFwd() bool {
if runtime.GOOS != "linux" {
return false
}
if s.customAddr != nil {
return false
}
if s.listenPort == defaultPort {
return false
}
return true
}

View File

@@ -1,139 +0,0 @@
package dns
import (
"fmt"
"math/big"
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
type serviceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
runtimePort int
udpFilterHookID string
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &serviceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
runtimePort: defaultPort,
}
return s
}
func (s *serviceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return err
}
s.listenerIsRunning = true
log.Debugf("dns service listening on: %s", s.RuntimeIP())
return nil
}
func (s *serviceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
s.listenerIsRunning = false
}
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaMemory) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaMemory) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}

View File

@@ -1,31 +0,0 @@
package dns
import (
"net"
"testing"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := getLastIPFromNetwork(ipnet, 1)
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -15,6 +15,7 @@ import (
"golang.org/x/sys/unix"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
)
const (
@@ -52,7 +53,7 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool
}
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
iface, err := net.InterfaceByName(wgInterface.Name())
if err != nil {
return nil, err

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
@@ -36,55 +35,25 @@ type upstreamResolver struct {
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
lIP net.IP
lNet *net.IPNet
lName string
iIndex int
deactivate func()
reactivate func()
}
func getInterfaceIndex(interfaceName string) (int, error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
log.Errorf("unable to get interface by name error: %s", err)
return 0, err
}
return iface.Index, nil
}
func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr string) *upstreamResolver {
func newUpstreamResolver(parentCTX context.Context) *upstreamResolver {
ctx, cancel := context.WithCancel(parentCTX)
// Specify the local IP address you want to bind to
localIP, localNet, err := net.ParseCIDR(wgAddr) // Should be our interface IP
if err != nil {
log.Errorf("error while parsing CIDR: %s", err)
}
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
}
localIFaceIndex := index // Should be our interface index
return &upstreamResolver{
ctx: ctx,
cancel: cancel,
upstreamClient: &dns.Client{},
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
lIP: localIP,
lNet: localNet,
iIndex: localIFaceIndex,
lName: interfaceName,
}
}
func (u *upstreamResolver) stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
log.Debugf("stoping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()
}
@@ -101,57 +70,26 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
for _, upstream := range u.upstreamServers {
var (
exchangeErr error
t time.Duration
rm *dns.Msg
)
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
upstreamExchangeClient := &dns.Client{}
if runtime.GOOS != "ios" {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, exchangeErr = upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
cancel()
} else {
upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil {
log.Errorf("error while parsing upstream host: %s", err)
return
}
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
upstreamExchangeClient = u.getClientPrivate()
}
rm, t, exchangeErr = upstreamExchangeClient.Exchange(r, upstream)
}
cancel()
if exchangeErr != nil {
if exchangeErr == context.DeadlineExceeded || isTimeout(exchangeErr) {
log.WithError(exchangeErr).WithField("upstream", upstream).
if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) {
log.WithError(err).WithField("upstream", upstream).
Warn("got an error while connecting to upstream")
continue
}
u.failsCount.Add(1)
log.WithError(exchangeErr).WithField("upstream", upstream).
Error("got other error while querying the upstream")
return
}
if rm == nil {
log.WithError(exchangeErr).WithField("upstream", upstream).
Warn("no response from upstream")
return
}
// those checks need to be independent of each other due to memory address issues
if !rm.Response {
log.WithError(exchangeErr).WithField("upstream", upstream).
Warn("no response from upstream")
log.WithError(err).WithField("upstream", upstream).
Error("got an error while querying the upstream")
return
}
log.Tracef("took %s to query the upstream %s", t, upstream)
err := w.WriteMsg(rm)
err = w.WriteMsg(rm)
if err != nil {
log.WithError(err).Error("got an error while writing the upstream resolver response")
}
@@ -180,7 +118,6 @@ func (u *upstreamResolver) checkUpstreamFails() {
case <-u.ctx.Done():
return
default:
// todo test the deactivation logic, it seems to affect the client
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
u.deactivate()
u.disabled = true

View File

@@ -1,44 +0,0 @@
//go:build ios
package dns
import (
"net"
"syscall"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func (u *upstreamResolver) getClientPrivate() *dns.Client {
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: u.lIP,
Port: 0, // Let the OS pick a free port
},
Timeout: upstreamTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
}
if err := c.Control(fn); err != nil {
return err
}
if operr != nil {
log.Errorf("error while setting socket option: %s", operr)
}
return operr
},
}
client := &dns.Client{
Dialer: dialer,
}
return client
}

View File

@@ -1,19 +0,0 @@
//go:build !ios
package dns
import (
"net"
"github.com/miekg/dns"
)
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func (u *upstreamResolver) getClientPrivate() *dns.Client {
dialer := &net.Dialer{}
client := &dns.Client{
Dialer: dialer,
}
return client
}

View File

@@ -49,15 +49,15 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
timeout: upstreamTimeout,
responseShouldBeNil: true,
},
// {
//{
// name: "Should Resolve CNAME Record",
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
// },
// {
//},
//{
// name: "Should Not Write When Not Found A Record",
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
// responseShouldBeNil: true,
// },
//},
}
// should resolve if first upstream times out
// should not write when both fails
@@ -66,7 +66,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver := newUpstreamResolver(ctx, "", "")
resolver := newUpstreamResolver(ctx)
resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {

View File

@@ -1,14 +0,0 @@
//go:build !windows
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
}

View File

@@ -1,13 +0,0 @@
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
GetInterfaceGUIDString() (string, error)
}

View File

@@ -1,129 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
NbFeatures *ebpf.Map `ebpf:"nb_features"`
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.NbFeatures,
m.NbMapDnsIp,
m.NbMapDnsPort,
m.NbWgProxySettingsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.NbXdpProg,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

View File

@@ -1,129 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
NbFeatures *ebpf.Map `ebpf:"nb_features"`
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.NbFeatures,
m.NbMapDnsIp,
m.NbMapDnsPort,
m.NbWgProxySettingsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.NbXdpProg,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfel.o
var _BpfBytes []byte

View File

@@ -1,51 +0,0 @@
package ebpf
import (
"encoding/binary"
"net"
log "github.com/sirupsen/logrus"
)
const (
mapKeyDNSIP uint32 = 0
mapKeyDNSPort uint32 = 1
)
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
log.Debugf("load ebpf DNS forwarder: address: %s:%d", ip, dnsPort)
tf.lock.Lock()
defer tf.lock.Unlock()
err := tf.loadXdp()
if err != nil {
return err
}
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
if err != nil {
return err
}
err = tf.bpfObjs.NbMapDnsPort.Put(mapKeyDNSPort, uint16(dnsPort))
if err != nil {
return err
}
tf.setFeatureFlag(featureFlagDnsForwarder)
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
if err != nil {
return err
}
return nil
}
func (tf *GeneralManager) FreeDNSFwd() error {
log.Debugf("free ebpf DNS forwarder")
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
}
func ip2int(ipString string) uint32 {
ip := net.ParseIP(ipString)
return binary.BigEndian.Uint32(ip.To4())
}

View File

@@ -1,116 +0,0 @@
package ebpf
import (
_ "embed"
"net"
"sync"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/rlimit"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
const (
mapKeyFeatures uint32 = 0
featureFlagWGProxy = 0b00000001
featureFlagDnsForwarder = 0b00000010
)
var (
singleton manager.Manager
singletonLock = &sync.Mutex{}
)
// required packages libbpf-dev, libc6-dev-i386-amd64-cross
// GeneralManager is used to load multiple eBPF programs with a custom check (if then) done in prog.c
// The manager simply adds a feature (byte) of each program to a map that is shared between the userspace and kernel.
// When packet arrives, the C code checks for each feature (if it is set) and executes each enabled program (e.g., dns_fwd.c and wg_proxy.c).
//
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/prog.c -- -I /usr/x86_64-linux-gnu/include
type GeneralManager struct {
lock sync.Mutex
link link.Link
featureFlags uint16
bpfObjs bpfObjects
}
// GetEbpfManagerInstance return a static eBpf Manager instance
func GetEbpfManagerInstance() manager.Manager {
singletonLock.Lock()
defer singletonLock.Unlock()
if singleton != nil {
return singleton
}
singleton = &GeneralManager{}
return singleton
}
func (tf *GeneralManager) setFeatureFlag(feature uint16) {
tf.featureFlags |= feature
}
func (tf *GeneralManager) loadXdp() error {
if tf.link != nil {
return nil
}
// it required for Docker
err := rlimit.RemoveMemlock()
if err != nil {
return err
}
iFace, err := net.InterfaceByName("lo")
if err != nil {
return err
}
// load pre-compiled programs into the kernel.
err = loadBpfObjects(&tf.bpfObjs, nil)
if err != nil {
return err
}
tf.link, err = link.AttachXDP(link.XDPOptions{
Program: tf.bpfObjs.NbXdpProg,
Interface: iFace.Index,
})
if err != nil {
_ = tf.bpfObjs.Close()
tf.link = nil
return err
}
return nil
}
func (tf *GeneralManager) unsetFeatureFlag(feature uint16) error {
tf.lock.Lock()
defer tf.lock.Unlock()
tf.featureFlags &^= feature
if tf.link == nil {
return nil
}
if tf.featureFlags == 0 {
return tf.close()
}
return tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
}
func (tf *GeneralManager) close() error {
log.Debugf("detach ebpf program ")
err := tf.bpfObjs.Close()
if err != nil {
log.Warnf("failed to close eBpf objects: %s", err)
}
err = tf.link.Close()
tf.link = nil
return err
}

View File

@@ -1,40 +0,0 @@
package ebpf
import (
"testing"
)
func TestManager_setFeatureFlag(t *testing.T) {
mgr := GeneralManager{}
mgr.setFeatureFlag(featureFlagWGProxy)
if mgr.featureFlags != 1 {
t.Errorf("invalid feature state")
}
mgr.setFeatureFlag(featureFlagDnsForwarder)
if mgr.featureFlags != 3 {
t.Errorf("invalid feature state")
}
}
func TestManager_unsetFeatureFlag(t *testing.T) {
mgr := GeneralManager{}
mgr.setFeatureFlag(featureFlagWGProxy)
mgr.setFeatureFlag(featureFlagDnsForwarder)
err := mgr.unsetFeatureFlag(featureFlagWGProxy)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if mgr.featureFlags != 2 {
t.Errorf("invalid feature state, expected: %d, got: %d", 2, mgr.featureFlags)
}
err = mgr.unsetFeatureFlag(featureFlagDnsForwarder)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if mgr.featureFlags != 0 {
t.Errorf("invalid feature state, expected: %d, got: %d", 0, mgr.featureFlags)
}
}

Some files were not shown because too many files have changed in this diff Show More