mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
[management, reverse proxy] Add reverse proxy feature (#5291)
* implement reverse proxy --------- Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk> Co-authored-by: mlsmaycon <mlsmaycon@gmail.com> Co-authored-by: Eduard Gert <kontakt@eduardgert.de> Co-authored-by: Viktor Liu <viktor@netbird.io> Co-authored-by: Diego Noguês <diego.sure@gmail.com> Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com> Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com> Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
*.pem
|
||||||
|
*.key
|
||||||
|
*.crt
|
||||||
|
*.p12
|
||||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Find all directories except the problematic ones and system dirs
|
||||||
@@ -31,7 +31,7 @@ jobs:
|
|||||||
while IFS= read -r dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
if [ -n "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
@@ -39,11 +39,11 @@ jobs:
|
|||||||
else
|
else
|
||||||
echo "✓ No problematic dependencies found"
|
echo "✓ No problematic dependencies found"
|
||||||
fi
|
fi
|
||||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
# Check if any importer is NOT in management/signal/relay
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
|||||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/golang-test-freebsd.yml
vendored
1
.github/workflows/golang-test-freebsd.yml
vendored
@@ -46,6 +46,5 @@ jobs:
|
|||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -timeout 1m -failfast ./sharedsock/...
|
||||||
time go test -timeout 1m -failfast ./signal/...
|
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
61
.github/workflows/golang-test-linux.yml
vendored
61
.github/workflows/golang-test-linux.yml
vendored
@@ -97,6 +97,16 @@ jobs:
|
|||||||
working-directory: relay
|
working-directory: relay
|
||||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||||
|
|
||||||
|
- name: Build combined
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: combined
|
||||||
|
run: CGO_ENABLED=1 go build .
|
||||||
|
|
||||||
|
- name: Build combined 386
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: combined
|
||||||
|
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: "Client / Unit"
|
name: "Client / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -144,7 +154,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
@@ -204,7 +214,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -261,6 +271,53 @@ jobs:
|
|||||||
-exec 'sudo' \
|
-exec 'sudo' \
|
||||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||||
|
|
||||||
|
test_proxy:
|
||||||
|
name: "Proxy / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test -timeout 10m -p 1 ./proxy/...
|
||||||
|
|
||||||
test_signal:
|
test_signal:
|
||||||
name: "Signal / Unit"
|
name: "Signal / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||||
|
|||||||
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -19,8 +19,8 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
16
.github/workflows/release.yml
vendored
16
.github/workflows/release.yml
vendored
@@ -160,7 +160,7 @@ jobs:
|
|||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
- name: Log in to the GitHub container registry
|
- name: Log in to the GitHub container registry
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
@@ -176,6 +176,7 @@ jobs:
|
|||||||
- name: Generate windows syso arm64
|
- name: Generate windows syso arm64
|
||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
@@ -185,6 +186,19 @@ jobs:
|
|||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
|
- name: Tag and push PR images (amd64 only)
|
||||||
|
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
|
||||||
|
run: |
|
||||||
|
PR_TAG="pr-${{ github.event.pull_request.number }}"
|
||||||
|
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||||
|
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||||
|
grep '^ghcr.io/' | while read -r SRC; do
|
||||||
|
IMG_NAME="${SRC%%:*}"
|
||||||
|
DST="${IMG_NAME}:${PR_TAG}"
|
||||||
|
echo "Tagging ${SRC} -> ${DST}"
|
||||||
|
docker tag "$SRC" "$DST"
|
||||||
|
docker push "$DST"
|
||||||
|
done
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
|||||||
.run
|
.run
|
||||||
*.iml
|
*.iml
|
||||||
dist/
|
dist/
|
||||||
|
!proxy/web/dist/
|
||||||
bin/
|
bin/
|
||||||
.env
|
.env
|
||||||
conf.json
|
conf.json
|
||||||
|
|||||||
@@ -140,6 +140,20 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-proxy
|
||||||
|
dir: proxy/cmd/proxy
|
||||||
|
env: [CGO_ENABLED=0]
|
||||||
|
binary: netbird-proxy
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
universal_binaries:
|
universal_binaries:
|
||||||
- id: netbird
|
- id: netbird
|
||||||
|
|
||||||
@@ -589,6 +603,55 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
docker_manifests:
|
docker_manifests:
|
||||||
- name_template: netbirdio/netbird:{{ .Version }}
|
- name_template: netbirdio/netbird:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
@@ -769,6 +832,30 @@ docker_manifests:
|
|||||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/reverse-proxy:latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
brews:
|
brews:
|
||||||
- ids:
|
- ids:
|
||||||
- default
|
- default
|
||||||
|
|||||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
|||||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
|
||||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||||
|
|
||||||
BSD 3-Clause License
|
BSD 3-Clause License
|
||||||
|
|||||||
@@ -31,6 +31,14 @@ var (
|
|||||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PeerConnStatus is a peer's connection status.
|
||||||
|
type PeerConnStatus = peer.ConnStatus
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PeerStatusConnected indicates the peer is in connected state.
|
||||||
|
PeerStatusConnected = peer.StatusConnected
|
||||||
|
)
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
@@ -162,6 +170,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
jwtToken: opts.JWTToken,
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
|
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,6 +192,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
|
||||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create auth client: %w", err)
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
@@ -192,10 +202,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
|
||||||
c.recorder = recorder
|
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
|
||||||
client.SetSyncResponsePersistence(true)
|
client.SetSyncResponsePersistence(true)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
@@ -348,14 +355,9 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
recorder := c.recorder
|
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if recorder == nil {
|
|
||||||
return peer.FullStatus{}, errors.New("client not started")
|
|
||||||
}
|
|
||||||
|
|
||||||
if connect != nil {
|
if connect != nil {
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine != nil {
|
if engine != nil {
|
||||||
@@ -363,7 +365,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return recorder.GetFullStatus(), nil
|
return c.recorder.GetFullStatus(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||||
|
|||||||
@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
|||||||
return t.tombstone.Load()
|
return t.tombstone.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSupersededBy returns true if this connection should be replaced by a new one
|
||||||
|
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
||||||
|
// connections are superseded by a pure SYN (a new connection attempt for the same
|
||||||
|
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
||||||
|
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
||||||
|
if t.tombstone.Load() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
||||||
|
}
|
||||||
|
|
||||||
// SetTombstone safely marks the connection for deletion
|
// SetTombstone safely marks the connection for deletion
|
||||||
func (t *TCPConnTrack) SetTombstone() {
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
t.tombstone.Store(true)
|
t.tombstone.Store(true)
|
||||||
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists && !conn.IsSupersededBy(flags) {
|
||||||
t.updateState(key, conn, flags, direction, size)
|
t.updateState(key, conn, flags, direction, size)
|
||||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||||
}
|
}
|
||||||
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists || conn.IsTombstone() {
|
if !exists || conn.IsSupersededBy(flags) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
||||||
|
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
||||||
|
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
||||||
|
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
||||||
|
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
||||||
|
func TestTCPPortReuseTombstone(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and gracefully close a connection (server-initiated close)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Server sends FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Client sends FIN-ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
|
// Server sends final ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Connection should be tombstoned
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn, "old connection should still be in map")
|
||||||
|
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
||||||
|
|
||||||
|
// Now reuse the same port for a new connection
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// The old tombstoned entry should be replaced with a new one
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
// SYN-ACK for the new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
|
||||||
|
// Data transfer should work
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
||||||
|
require.True(t, valid, "data should be allowed on new connection")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and RST a connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
||||||
|
|
||||||
|
// Reuse the same port
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
clientIP := srcIP
|
||||||
|
serverIP := dstIP
|
||||||
|
clientPort := srcPort
|
||||||
|
serverPort := dstPort
|
||||||
|
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
||||||
|
|
||||||
|
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Server-initiated close to reach Closed/tombstoned:
|
||||||
|
// Server FIN (opposite dir) → CloseWait
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
||||||
|
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||||
|
// Client FIN-ACK (same dir as conn) → LastAck
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||||
|
// Server final ACK (opposite dir) → Closed → tombstoned
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||||
|
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// New inbound connection on same ports
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
|
||||||
|
// Complete handshake: server SYN-ACK, then client ACK
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// Late ACK should be rejected (tombstoned)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
||||||
|
|
||||||
|
// Late outbound ACK should not create a new connection (not a SYN)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPortReuseTimeWait(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Active close: client (outbound initiator) sends FIN first
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Server ACKs the FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
|
||||||
|
// Server sends its own FIN
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
||||||
|
|
||||||
|
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
||||||
|
|
||||||
|
// SYN-ACK for new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish outbound connection and close via active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
||||||
|
// so the filter falls through to ACL check + TrackInbound (which creates
|
||||||
|
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
||||||
|
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
||||||
|
|
||||||
|
// Simulate what the filter does next: TrackInbound via the normal path
|
||||||
|
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
||||||
|
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
||||||
|
newConn := tracker.connections[invertedKey]
|
||||||
|
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Late ACK retransmits during TIME-WAIT should still be accepted
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestTCPTimeoutHandling(t *testing.T) {
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
// Create tracker with a very short timeout for testing
|
// Create tracker with a very short timeout for testing
|
||||||
shortTimeout := 100 * time.Millisecond
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,9 +18,18 @@ const (
|
|||||||
maxBatchSize = 1024 * 16
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2
|
maxMessageSize = 1024 * 2
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
logChannelSize = 1000
|
defaultLogChanSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getLogChannelSize() int {
|
||||||
|
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultLogChanSize
|
||||||
|
}
|
||||||
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -69,7 +80,7 @@ type Logger struct {
|
|||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
msgChannel: make(chan logMessage, logChannelSize),
|
msgChannel: make(chan logMessage, getLogChannelSize()),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
@@ -1923,7 +1924,7 @@ func (e *Engine) triggerClientRestart() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) startNetworkMonitor() {
|
func (e *Engine) startNetworkMonitor() {
|
||||||
if !e.config.NetworkMonitor {
|
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
|
||||||
log.Infof("Network monitor is disabled, not starting")
|
log.Infof("Network monitor is disabled, not starting")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,11 +37,6 @@ func New() *NetworkMonitor {
|
|||||||
|
|
||||||
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||||
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||||
if netstack.IsEnabled() {
|
|
||||||
log.Debugf("Network monitor: skipping in netstack mode")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nw.mu.Lock()
|
nw.mu.Lock()
|
||||||
if nw.cancel != nil {
|
if nw.cancel != nil {
|
||||||
nw.mu.Unlock()
|
nw.mu.Unlock()
|
||||||
|
|||||||
25
combined/Dockerfile.multistage
Normal file
25
combined/Dockerfile.multistage
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
FROM golang:1.25-bookworm AS builder
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apt-get update && apt-get install -y gcc libc6-dev git && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Build with version info from git (matching goreleaser ldflags)
|
||||||
|
RUN CGO_ENABLED=1 GOOS=linux go build \
|
||||||
|
-ldflags="-s -w \
|
||||||
|
-X github.com/netbirdio/netbird/version.version=$(git describe --tags --always --dirty 2>/dev/null || echo 'dev') \
|
||||||
|
-X main.commit=$(git rev-parse --short HEAD 2>/dev/null || echo 'unknown') \
|
||||||
|
-X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ) \
|
||||||
|
-X main.builtBy=docker" \
|
||||||
|
-o netbird-server ./combined
|
||||||
|
|
||||||
|
FROM ubuntu:24.04
|
||||||
|
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||||
|
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||||
|
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||||
|
COPY --from=builder /app/netbird-server /go/bin/netbird-server
|
||||||
661
combined/LICENSE
Normal file
661
combined/LICENSE
Normal file
@@ -0,0 +1,661 @@
|
|||||||
|
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||||
|
Version 3, 19 November 2007
|
||||||
|
|
||||||
|
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||||
|
Everyone is permitted to copy and distribute verbatim copies
|
||||||
|
of this license document, but changing it is not allowed.
|
||||||
|
|
||||||
|
Preamble
|
||||||
|
|
||||||
|
The GNU Affero General Public License is a free, copyleft license for
|
||||||
|
software and other kinds of works, specifically designed to ensure
|
||||||
|
cooperation with the community in the case of network server software.
|
||||||
|
|
||||||
|
The licenses for most software and other practical works are designed
|
||||||
|
to take away your freedom to share and change the works. By contrast,
|
||||||
|
our General Public Licenses are intended to guarantee your freedom to
|
||||||
|
share and change all versions of a program--to make sure it remains free
|
||||||
|
software for all its users.
|
||||||
|
|
||||||
|
When we speak of free software, we are referring to freedom, not
|
||||||
|
price. Our General Public Licenses are designed to make sure that you
|
||||||
|
have the freedom to distribute copies of free software (and charge for
|
||||||
|
them if you wish), that you receive source code or can get it if you
|
||||||
|
want it, that you can change the software or use pieces of it in new
|
||||||
|
free programs, and that you know you can do these things.
|
||||||
|
|
||||||
|
Developers that use our General Public Licenses protect your rights
|
||||||
|
with two steps: (1) assert copyright on the software, and (2) offer
|
||||||
|
you this License which gives you legal permission to copy, distribute
|
||||||
|
and/or modify the software.
|
||||||
|
|
||||||
|
A secondary benefit of defending all users' freedom is that
|
||||||
|
improvements made in alternate versions of the program, if they
|
||||||
|
receive widespread use, become available for other developers to
|
||||||
|
incorporate. Many developers of free software are heartened and
|
||||||
|
encouraged by the resulting cooperation. However, in the case of
|
||||||
|
software used on network servers, this result may fail to come about.
|
||||||
|
The GNU General Public License permits making a modified version and
|
||||||
|
letting the public access it on a server without ever releasing its
|
||||||
|
source code to the public.
|
||||||
|
|
||||||
|
The GNU Affero General Public License is designed specifically to
|
||||||
|
ensure that, in such cases, the modified source code becomes available
|
||||||
|
to the community. It requires the operator of a network server to
|
||||||
|
provide the source code of the modified version running there to the
|
||||||
|
users of that server. Therefore, public use of a modified version, on
|
||||||
|
a publicly accessible server, gives the public access to the source
|
||||||
|
code of the modified version.
|
||||||
|
|
||||||
|
An older license, called the Affero General Public License and
|
||||||
|
published by Affero, was designed to accomplish similar goals. This is
|
||||||
|
a different license, not a version of the Affero GPL, but Affero has
|
||||||
|
released a new version of the Affero GPL which permits relicensing under
|
||||||
|
this license.
|
||||||
|
|
||||||
|
The precise terms and conditions for copying, distribution and
|
||||||
|
modification follow.
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
0. Definitions.
|
||||||
|
|
||||||
|
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||||
|
|
||||||
|
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||||
|
works, such as semiconductor masks.
|
||||||
|
|
||||||
|
"The Program" refers to any copyrightable work licensed under this
|
||||||
|
License. Each licensee is addressed as "you". "Licensees" and
|
||||||
|
"recipients" may be individuals or organizations.
|
||||||
|
|
||||||
|
To "modify" a work means to copy from or adapt all or part of the work
|
||||||
|
in a fashion requiring copyright permission, other than the making of an
|
||||||
|
exact copy. The resulting work is called a "modified version" of the
|
||||||
|
earlier work or a work "based on" the earlier work.
|
||||||
|
|
||||||
|
A "covered work" means either the unmodified Program or a work based
|
||||||
|
on the Program.
|
||||||
|
|
||||||
|
To "propagate" a work means to do anything with it that, without
|
||||||
|
permission, would make you directly or secondarily liable for
|
||||||
|
infringement under applicable copyright law, except executing it on a
|
||||||
|
computer or modifying a private copy. Propagation includes copying,
|
||||||
|
distribution (with or without modification), making available to the
|
||||||
|
public, and in some countries other activities as well.
|
||||||
|
|
||||||
|
To "convey" a work means any kind of propagation that enables other
|
||||||
|
parties to make or receive copies. Mere interaction with a user through
|
||||||
|
a computer network, with no transfer of a copy, is not conveying.
|
||||||
|
|
||||||
|
An interactive user interface displays "Appropriate Legal Notices"
|
||||||
|
to the extent that it includes a convenient and prominently visible
|
||||||
|
feature that (1) displays an appropriate copyright notice, and (2)
|
||||||
|
tells the user that there is no warranty for the work (except to the
|
||||||
|
extent that warranties are provided), that licensees may convey the
|
||||||
|
work under this License, and how to view a copy of this License. If
|
||||||
|
the interface presents a list of user commands or options, such as a
|
||||||
|
menu, a prominent item in the list meets this criterion.
|
||||||
|
|
||||||
|
1. Source Code.
|
||||||
|
|
||||||
|
The "source code" for a work means the preferred form of the work
|
||||||
|
for making modifications to it. "Object code" means any non-source
|
||||||
|
form of a work.
|
||||||
|
|
||||||
|
A "Standard Interface" means an interface that either is an official
|
||||||
|
standard defined by a recognized standards body, or, in the case of
|
||||||
|
interfaces specified for a particular programming language, one that
|
||||||
|
is widely used among developers working in that language.
|
||||||
|
|
||||||
|
The "System Libraries" of an executable work include anything, other
|
||||||
|
than the work as a whole, that (a) is included in the normal form of
|
||||||
|
packaging a Major Component, but which is not part of that Major
|
||||||
|
Component, and (b) serves only to enable use of the work with that
|
||||||
|
Major Component, or to implement a Standard Interface for which an
|
||||||
|
implementation is available to the public in source code form. A
|
||||||
|
"Major Component", in this context, means a major essential component
|
||||||
|
(kernel, window system, and so on) of the specific operating system
|
||||||
|
(if any) on which the executable work runs, or a compiler used to
|
||||||
|
produce the work, or an object code interpreter used to run it.
|
||||||
|
|
||||||
|
The "Corresponding Source" for a work in object code form means all
|
||||||
|
the source code needed to generate, install, and (for an executable
|
||||||
|
work) run the object code and to modify the work, including scripts to
|
||||||
|
control those activities. However, it does not include the work's
|
||||||
|
System Libraries, or general-purpose tools or generally available free
|
||||||
|
programs which are used unmodified in performing those activities but
|
||||||
|
which are not part of the work. For example, Corresponding Source
|
||||||
|
includes interface definition files associated with source files for
|
||||||
|
the work, and the source code for shared libraries and dynamically
|
||||||
|
linked subprograms that the work is specifically designed to require,
|
||||||
|
such as by intimate data communication or control flow between those
|
||||||
|
subprograms and other parts of the work.
|
||||||
|
|
||||||
|
The Corresponding Source need not include anything that users
|
||||||
|
can regenerate automatically from other parts of the Corresponding
|
||||||
|
Source.
|
||||||
|
|
||||||
|
The Corresponding Source for a work in source code form is that
|
||||||
|
same work.
|
||||||
|
|
||||||
|
2. Basic Permissions.
|
||||||
|
|
||||||
|
All rights granted under this License are granted for the term of
|
||||||
|
copyright on the Program, and are irrevocable provided the stated
|
||||||
|
conditions are met. This License explicitly affirms your unlimited
|
||||||
|
permission to run the unmodified Program. The output from running a
|
||||||
|
covered work is covered by this License only if the output, given its
|
||||||
|
content, constitutes a covered work. This License acknowledges your
|
||||||
|
rights of fair use or other equivalent, as provided by copyright law.
|
||||||
|
|
||||||
|
You may make, run and propagate covered works that you do not
|
||||||
|
convey, without conditions so long as your license otherwise remains
|
||||||
|
in force. You may convey covered works to others for the sole purpose
|
||||||
|
of having them make modifications exclusively for you, or provide you
|
||||||
|
with facilities for running those works, provided that you comply with
|
||||||
|
the terms of this License in conveying all material for which you do
|
||||||
|
not control copyright. Those thus making or running the covered works
|
||||||
|
for you must do so exclusively on your behalf, under your direction
|
||||||
|
and control, on terms that prohibit them from making any copies of
|
||||||
|
your copyrighted material outside their relationship with you.
|
||||||
|
|
||||||
|
Conveying under any other circumstances is permitted solely under
|
||||||
|
the conditions stated below. Sublicensing is not allowed; section 10
|
||||||
|
makes it unnecessary.
|
||||||
|
|
||||||
|
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||||
|
|
||||||
|
No covered work shall be deemed part of an effective technological
|
||||||
|
measure under any applicable law fulfilling obligations under article
|
||||||
|
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||||
|
similar laws prohibiting or restricting circumvention of such
|
||||||
|
measures.
|
||||||
|
|
||||||
|
When you convey a covered work, you waive any legal power to forbid
|
||||||
|
circumvention of technological measures to the extent such circumvention
|
||||||
|
is effected by exercising rights under this License with respect to
|
||||||
|
the covered work, and you disclaim any intention to limit operation or
|
||||||
|
modification of the work as a means of enforcing, against the work's
|
||||||
|
users, your or third parties' legal rights to forbid circumvention of
|
||||||
|
technological measures.
|
||||||
|
|
||||||
|
4. Conveying Verbatim Copies.
|
||||||
|
|
||||||
|
You may convey verbatim copies of the Program's source code as you
|
||||||
|
receive it, in any medium, provided that you conspicuously and
|
||||||
|
appropriately publish on each copy an appropriate copyright notice;
|
||||||
|
keep intact all notices stating that this License and any
|
||||||
|
non-permissive terms added in accord with section 7 apply to the code;
|
||||||
|
keep intact all notices of the absence of any warranty; and give all
|
||||||
|
recipients a copy of this License along with the Program.
|
||||||
|
|
||||||
|
You may charge any price or no price for each copy that you convey,
|
||||||
|
and you may offer support or warranty protection for a fee.
|
||||||
|
|
||||||
|
5. Conveying Modified Source Versions.
|
||||||
|
|
||||||
|
You may convey a work based on the Program, or the modifications to
|
||||||
|
produce it from the Program, in the form of source code under the
|
||||||
|
terms of section 4, provided that you also meet all of these conditions:
|
||||||
|
|
||||||
|
a) The work must carry prominent notices stating that you modified
|
||||||
|
it, and giving a relevant date.
|
||||||
|
|
||||||
|
b) The work must carry prominent notices stating that it is
|
||||||
|
released under this License and any conditions added under section
|
||||||
|
7. This requirement modifies the requirement in section 4 to
|
||||||
|
"keep intact all notices".
|
||||||
|
|
||||||
|
c) You must license the entire work, as a whole, under this
|
||||||
|
License to anyone who comes into possession of a copy. This
|
||||||
|
License will therefore apply, along with any applicable section 7
|
||||||
|
additional terms, to the whole of the work, and all its parts,
|
||||||
|
regardless of how they are packaged. This License gives no
|
||||||
|
permission to license the work in any other way, but it does not
|
||||||
|
invalidate such permission if you have separately received it.
|
||||||
|
|
||||||
|
d) If the work has interactive user interfaces, each must display
|
||||||
|
Appropriate Legal Notices; however, if the Program has interactive
|
||||||
|
interfaces that do not display Appropriate Legal Notices, your
|
||||||
|
work need not make them do so.
|
||||||
|
|
||||||
|
A compilation of a covered work with other separate and independent
|
||||||
|
works, which are not by their nature extensions of the covered work,
|
||||||
|
and which are not combined with it such as to form a larger program,
|
||||||
|
in or on a volume of a storage or distribution medium, is called an
|
||||||
|
"aggregate" if the compilation and its resulting copyright are not
|
||||||
|
used to limit the access or legal rights of the compilation's users
|
||||||
|
beyond what the individual works permit. Inclusion of a covered work
|
||||||
|
in an aggregate does not cause this License to apply to the other
|
||||||
|
parts of the aggregate.
|
||||||
|
|
||||||
|
6. Conveying Non-Source Forms.
|
||||||
|
|
||||||
|
You may convey a covered work in object code form under the terms
|
||||||
|
of sections 4 and 5, provided that you also convey the
|
||||||
|
machine-readable Corresponding Source under the terms of this License,
|
||||||
|
in one of these ways:
|
||||||
|
|
||||||
|
a) Convey the object code in, or embodied in, a physical product
|
||||||
|
(including a physical distribution medium), accompanied by the
|
||||||
|
Corresponding Source fixed on a durable physical medium
|
||||||
|
customarily used for software interchange.
|
||||||
|
|
||||||
|
b) Convey the object code in, or embodied in, a physical product
|
||||||
|
(including a physical distribution medium), accompanied by a
|
||||||
|
written offer, valid for at least three years and valid for as
|
||||||
|
long as you offer spare parts or customer support for that product
|
||||||
|
model, to give anyone who possesses the object code either (1) a
|
||||||
|
copy of the Corresponding Source for all the software in the
|
||||||
|
product that is covered by this License, on a durable physical
|
||||||
|
medium customarily used for software interchange, for a price no
|
||||||
|
more than your reasonable cost of physically performing this
|
||||||
|
conveying of source, or (2) access to copy the
|
||||||
|
Corresponding Source from a network server at no charge.
|
||||||
|
|
||||||
|
c) Convey individual copies of the object code with a copy of the
|
||||||
|
written offer to provide the Corresponding Source. This
|
||||||
|
alternative is allowed only occasionally and noncommercially, and
|
||||||
|
only if you received the object code with such an offer, in accord
|
||||||
|
with subsection 6b.
|
||||||
|
|
||||||
|
d) Convey the object code by offering access from a designated
|
||||||
|
place (gratis or for a charge), and offer equivalent access to the
|
||||||
|
Corresponding Source in the same way through the same place at no
|
||||||
|
further charge. You need not require recipients to copy the
|
||||||
|
Corresponding Source along with the object code. If the place to
|
||||||
|
copy the object code is a network server, the Corresponding Source
|
||||||
|
may be on a different server (operated by you or a third party)
|
||||||
|
that supports equivalent copying facilities, provided you maintain
|
||||||
|
clear directions next to the object code saying where to find the
|
||||||
|
Corresponding Source. Regardless of what server hosts the
|
||||||
|
Corresponding Source, you remain obligated to ensure that it is
|
||||||
|
available for as long as needed to satisfy these requirements.
|
||||||
|
|
||||||
|
e) Convey the object code using peer-to-peer transmission, provided
|
||||||
|
you inform other peers where the object code and Corresponding
|
||||||
|
Source of the work are being offered to the general public at no
|
||||||
|
charge under subsection 6d.
|
||||||
|
|
||||||
|
A separable portion of the object code, whose source code is excluded
|
||||||
|
from the Corresponding Source as a System Library, need not be
|
||||||
|
included in conveying the object code work.
|
||||||
|
|
||||||
|
A "User Product" is either (1) a "consumer product", which means any
|
||||||
|
tangible personal property which is normally used for personal, family,
|
||||||
|
or household purposes, or (2) anything designed or sold for incorporation
|
||||||
|
into a dwelling. In determining whether a product is a consumer product,
|
||||||
|
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||||
|
product received by a particular user, "normally used" refers to a
|
||||||
|
typical or common use of that class of product, regardless of the status
|
||||||
|
of the particular user or of the way in which the particular user
|
||||||
|
actually uses, or expects or is expected to use, the product. A product
|
||||||
|
is a consumer product regardless of whether the product has substantial
|
||||||
|
commercial, industrial or non-consumer uses, unless such uses represent
|
||||||
|
the only significant mode of use of the product.
|
||||||
|
|
||||||
|
"Installation Information" for a User Product means any methods,
|
||||||
|
procedures, authorization keys, or other information required to install
|
||||||
|
and execute modified versions of a covered work in that User Product from
|
||||||
|
a modified version of its Corresponding Source. The information must
|
||||||
|
suffice to ensure that the continued functioning of the modified object
|
||||||
|
code is in no case prevented or interfered with solely because
|
||||||
|
modification has been made.
|
||||||
|
|
||||||
|
If you convey an object code work under this section in, or with, or
|
||||||
|
specifically for use in, a User Product, and the conveying occurs as
|
||||||
|
part of a transaction in which the right of possession and use of the
|
||||||
|
User Product is transferred to the recipient in perpetuity or for a
|
||||||
|
fixed term (regardless of how the transaction is characterized), the
|
||||||
|
Corresponding Source conveyed under this section must be accompanied
|
||||||
|
by the Installation Information. But this requirement does not apply
|
||||||
|
if neither you nor any third party retains the ability to install
|
||||||
|
modified object code on the User Product (for example, the work has
|
||||||
|
been installed in ROM).
|
||||||
|
|
||||||
|
The requirement to provide Installation Information does not include a
|
||||||
|
requirement to continue to provide support service, warranty, or updates
|
||||||
|
for a work that has been modified or installed by the recipient, or for
|
||||||
|
the User Product in which it has been modified or installed. Access to a
|
||||||
|
network may be denied when the modification itself materially and
|
||||||
|
adversely affects the operation of the network or violates the rules and
|
||||||
|
protocols for communication across the network.
|
||||||
|
|
||||||
|
Corresponding Source conveyed, and Installation Information provided,
|
||||||
|
in accord with this section must be in a format that is publicly
|
||||||
|
documented (and with an implementation available to the public in
|
||||||
|
source code form), and must require no special password or key for
|
||||||
|
unpacking, reading or copying.
|
||||||
|
|
||||||
|
7. Additional Terms.
|
||||||
|
|
||||||
|
"Additional permissions" are terms that supplement the terms of this
|
||||||
|
License by making exceptions from one or more of its conditions.
|
||||||
|
Additional permissions that are applicable to the entire Program shall
|
||||||
|
be treated as though they were included in this License, to the extent
|
||||||
|
that they are valid under applicable law. If additional permissions
|
||||||
|
apply only to part of the Program, that part may be used separately
|
||||||
|
under those permissions, but the entire Program remains governed by
|
||||||
|
this License without regard to the additional permissions.
|
||||||
|
|
||||||
|
When you convey a copy of a covered work, you may at your option
|
||||||
|
remove any additional permissions from that copy, or from any part of
|
||||||
|
it. (Additional permissions may be written to require their own
|
||||||
|
removal in certain cases when you modify the work.) You may place
|
||||||
|
additional permissions on material, added by you to a covered work,
|
||||||
|
for which you have or can give appropriate copyright permission.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, for material you
|
||||||
|
add to a covered work, you may (if authorized by the copyright holders of
|
||||||
|
that material) supplement the terms of this License with terms:
|
||||||
|
|
||||||
|
a) Disclaiming warranty or limiting liability differently from the
|
||||||
|
terms of sections 15 and 16 of this License; or
|
||||||
|
|
||||||
|
b) Requiring preservation of specified reasonable legal notices or
|
||||||
|
author attributions in that material or in the Appropriate Legal
|
||||||
|
Notices displayed by works containing it; or
|
||||||
|
|
||||||
|
c) Prohibiting misrepresentation of the origin of that material, or
|
||||||
|
requiring that modified versions of such material be marked in
|
||||||
|
reasonable ways as different from the original version; or
|
||||||
|
|
||||||
|
d) Limiting the use for publicity purposes of names of licensors or
|
||||||
|
authors of the material; or
|
||||||
|
|
||||||
|
e) Declining to grant rights under trademark law for use of some
|
||||||
|
trade names, trademarks, or service marks; or
|
||||||
|
|
||||||
|
f) Requiring indemnification of licensors and authors of that
|
||||||
|
material by anyone who conveys the material (or modified versions of
|
||||||
|
it) with contractual assumptions of liability to the recipient, for
|
||||||
|
any liability that these contractual assumptions directly impose on
|
||||||
|
those licensors and authors.
|
||||||
|
|
||||||
|
All other non-permissive additional terms are considered "further
|
||||||
|
restrictions" within the meaning of section 10. If the Program as you
|
||||||
|
received it, or any part of it, contains a notice stating that it is
|
||||||
|
governed by this License along with a term that is a further
|
||||||
|
restriction, you may remove that term. If a license document contains
|
||||||
|
a further restriction but permits relicensing or conveying under this
|
||||||
|
License, you may add to a covered work material governed by the terms
|
||||||
|
of that license document, provided that the further restriction does
|
||||||
|
not survive such relicensing or conveying.
|
||||||
|
|
||||||
|
If you add terms to a covered work in accord with this section, you
|
||||||
|
must place, in the relevant source files, a statement of the
|
||||||
|
additional terms that apply to those files, or a notice indicating
|
||||||
|
where to find the applicable terms.
|
||||||
|
|
||||||
|
Additional terms, permissive or non-permissive, may be stated in the
|
||||||
|
form of a separately written license, or stated as exceptions;
|
||||||
|
the above requirements apply either way.
|
||||||
|
|
||||||
|
8. Termination.
|
||||||
|
|
||||||
|
You may not propagate or modify a covered work except as expressly
|
||||||
|
provided under this License. Any attempt otherwise to propagate or
|
||||||
|
modify it is void, and will automatically terminate your rights under
|
||||||
|
this License (including any patent licenses granted under the third
|
||||||
|
paragraph of section 11).
|
||||||
|
|
||||||
|
However, if you cease all violation of this License, then your
|
||||||
|
license from a particular copyright holder is reinstated (a)
|
||||||
|
provisionally, unless and until the copyright holder explicitly and
|
||||||
|
finally terminates your license, and (b) permanently, if the copyright
|
||||||
|
holder fails to notify you of the violation by some reasonable means
|
||||||
|
prior to 60 days after the cessation.
|
||||||
|
|
||||||
|
Moreover, your license from a particular copyright holder is
|
||||||
|
reinstated permanently if the copyright holder notifies you of the
|
||||||
|
violation by some reasonable means, this is the first time you have
|
||||||
|
received notice of violation of this License (for any work) from that
|
||||||
|
copyright holder, and you cure the violation prior to 30 days after
|
||||||
|
your receipt of the notice.
|
||||||
|
|
||||||
|
Termination of your rights under this section does not terminate the
|
||||||
|
licenses of parties who have received copies or rights from you under
|
||||||
|
this License. If your rights have been terminated and not permanently
|
||||||
|
reinstated, you do not qualify to receive new licenses for the same
|
||||||
|
material under section 10.
|
||||||
|
|
||||||
|
9. Acceptance Not Required for Having Copies.
|
||||||
|
|
||||||
|
You are not required to accept this License in order to receive or
|
||||||
|
run a copy of the Program. Ancillary propagation of a covered work
|
||||||
|
occurring solely as a consequence of using peer-to-peer transmission
|
||||||
|
to receive a copy likewise does not require acceptance. However,
|
||||||
|
nothing other than this License grants you permission to propagate or
|
||||||
|
modify any covered work. These actions infringe copyright if you do
|
||||||
|
not accept this License. Therefore, by modifying or propagating a
|
||||||
|
covered work, you indicate your acceptance of this License to do so.
|
||||||
|
|
||||||
|
10. Automatic Licensing of Downstream Recipients.
|
||||||
|
|
||||||
|
Each time you convey a covered work, the recipient automatically
|
||||||
|
receives a license from the original licensors, to run, modify and
|
||||||
|
propagate that work, subject to this License. You are not responsible
|
||||||
|
for enforcing compliance by third parties with this License.
|
||||||
|
|
||||||
|
An "entity transaction" is a transaction transferring control of an
|
||||||
|
organization, or substantially all assets of one, or subdividing an
|
||||||
|
organization, or merging organizations. If propagation of a covered
|
||||||
|
work results from an entity transaction, each party to that
|
||||||
|
transaction who receives a copy of the work also receives whatever
|
||||||
|
licenses to the work the party's predecessor in interest had or could
|
||||||
|
give under the previous paragraph, plus a right to possession of the
|
||||||
|
Corresponding Source of the work from the predecessor in interest, if
|
||||||
|
the predecessor has it or can get it with reasonable efforts.
|
||||||
|
|
||||||
|
You may not impose any further restrictions on the exercise of the
|
||||||
|
rights granted or affirmed under this License. For example, you may
|
||||||
|
not impose a license fee, royalty, or other charge for exercise of
|
||||||
|
rights granted under this License, and you may not initiate litigation
|
||||||
|
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||||
|
any patent claim is infringed by making, using, selling, offering for
|
||||||
|
sale, or importing the Program or any portion of it.
|
||||||
|
|
||||||
|
11. Patents.
|
||||||
|
|
||||||
|
A "contributor" is a copyright holder who authorizes use under this
|
||||||
|
License of the Program or a work on which the Program is based. The
|
||||||
|
work thus licensed is called the contributor's "contributor version".
|
||||||
|
|
||||||
|
A contributor's "essential patent claims" are all patent claims
|
||||||
|
owned or controlled by the contributor, whether already acquired or
|
||||||
|
hereafter acquired, that would be infringed by some manner, permitted
|
||||||
|
by this License, of making, using, or selling its contributor version,
|
||||||
|
but do not include claims that would be infringed only as a
|
||||||
|
consequence of further modification of the contributor version. For
|
||||||
|
purposes of this definition, "control" includes the right to grant
|
||||||
|
patent sublicenses in a manner consistent with the requirements of
|
||||||
|
this License.
|
||||||
|
|
||||||
|
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||||
|
patent license under the contributor's essential patent claims, to
|
||||||
|
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||||
|
propagate the contents of its contributor version.
|
||||||
|
|
||||||
|
In the following three paragraphs, a "patent license" is any express
|
||||||
|
agreement or commitment, however denominated, not to enforce a patent
|
||||||
|
(such as an express permission to practice a patent or covenant not to
|
||||||
|
sue for patent infringement). To "grant" such a patent license to a
|
||||||
|
party means to make such an agreement or commitment not to enforce a
|
||||||
|
patent against the party.
|
||||||
|
|
||||||
|
If you convey a covered work, knowingly relying on a patent license,
|
||||||
|
and the Corresponding Source of the work is not available for anyone
|
||||||
|
to copy, free of charge and under the terms of this License, through a
|
||||||
|
publicly available network server or other readily accessible means,
|
||||||
|
then you must either (1) cause the Corresponding Source to be so
|
||||||
|
available, or (2) arrange to deprive yourself of the benefit of the
|
||||||
|
patent license for this particular work, or (3) arrange, in a manner
|
||||||
|
consistent with the requirements of this License, to extend the patent
|
||||||
|
license to downstream recipients. "Knowingly relying" means you have
|
||||||
|
actual knowledge that, but for the patent license, your conveying the
|
||||||
|
covered work in a country, or your recipient's use of the covered work
|
||||||
|
in a country, would infringe one or more identifiable patents in that
|
||||||
|
country that you have reason to believe are valid.
|
||||||
|
|
||||||
|
If, pursuant to or in connection with a single transaction or
|
||||||
|
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||||
|
covered work, and grant a patent license to some of the parties
|
||||||
|
receiving the covered work authorizing them to use, propagate, modify
|
||||||
|
or convey a specific copy of the covered work, then the patent license
|
||||||
|
you grant is automatically extended to all recipients of the covered
|
||||||
|
work and works based on it.
|
||||||
|
|
||||||
|
A patent license is "discriminatory" if it does not include within
|
||||||
|
the scope of its coverage, prohibits the exercise of, or is
|
||||||
|
conditioned on the non-exercise of one or more of the rights that are
|
||||||
|
specifically granted under this License. You may not convey a covered
|
||||||
|
work if you are a party to an arrangement with a third party that is
|
||||||
|
in the business of distributing software, under which you make payment
|
||||||
|
to the third party based on the extent of your activity of conveying
|
||||||
|
the work, and under which the third party grants, to any of the
|
||||||
|
parties who would receive the covered work from you, a discriminatory
|
||||||
|
patent license (a) in connection with copies of the covered work
|
||||||
|
conveyed by you (or copies made from those copies), or (b) primarily
|
||||||
|
for and in connection with specific products or compilations that
|
||||||
|
contain the covered work, unless you entered into that arrangement,
|
||||||
|
or that patent license was granted, prior to 28 March 2007.
|
||||||
|
|
||||||
|
Nothing in this License shall be construed as excluding or limiting
|
||||||
|
any implied license or other defenses to infringement that may
|
||||||
|
otherwise be available to you under applicable patent law.
|
||||||
|
|
||||||
|
12. No Surrender of Others' Freedom.
|
||||||
|
|
||||||
|
If conditions are imposed on you (whether by court order, agreement or
|
||||||
|
otherwise) that contradict the conditions of this License, they do not
|
||||||
|
excuse you from the conditions of this License. If you cannot convey a
|
||||||
|
covered work so as to satisfy simultaneously your obligations under this
|
||||||
|
License and any other pertinent obligations, then as a consequence you may
|
||||||
|
not convey it at all. For example, if you agree to terms that obligate you
|
||||||
|
to collect a royalty for further conveying from those to whom you convey
|
||||||
|
the Program, the only way you could satisfy both those terms and this
|
||||||
|
License would be to refrain entirely from conveying the Program.
|
||||||
|
|
||||||
|
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, if you modify the
|
||||||
|
Program, your modified version must prominently offer all users
|
||||||
|
interacting with it remotely through a computer network (if your version
|
||||||
|
supports such interaction) an opportunity to receive the Corresponding
|
||||||
|
Source of your version by providing access to the Corresponding Source
|
||||||
|
from a network server at no charge, through some standard or customary
|
||||||
|
means of facilitating copying of software. This Corresponding Source
|
||||||
|
shall include the Corresponding Source for any work covered by version 3
|
||||||
|
of the GNU General Public License that is incorporated pursuant to the
|
||||||
|
following paragraph.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, you have
|
||||||
|
permission to link or combine any covered work with a work licensed
|
||||||
|
under version 3 of the GNU General Public License into a single
|
||||||
|
combined work, and to convey the resulting work. The terms of this
|
||||||
|
License will continue to apply to the part which is the covered work,
|
||||||
|
but the work with which it is combined will remain governed by version
|
||||||
|
3 of the GNU General Public License.
|
||||||
|
|
||||||
|
14. Revised Versions of this License.
|
||||||
|
|
||||||
|
The Free Software Foundation may publish revised and/or new versions of
|
||||||
|
the GNU Affero General Public License from time to time. Such new versions
|
||||||
|
will be similar in spirit to the present version, but may differ in detail to
|
||||||
|
address new problems or concerns.
|
||||||
|
|
||||||
|
Each version is given a distinguishing version number. If the
|
||||||
|
Program specifies that a certain numbered version of the GNU Affero General
|
||||||
|
Public License "or any later version" applies to it, you have the
|
||||||
|
option of following the terms and conditions either of that numbered
|
||||||
|
version or of any later version published by the Free Software
|
||||||
|
Foundation. If the Program does not specify a version number of the
|
||||||
|
GNU Affero General Public License, you may choose any version ever published
|
||||||
|
by the Free Software Foundation.
|
||||||
|
|
||||||
|
If the Program specifies that a proxy can decide which future
|
||||||
|
versions of the GNU Affero General Public License can be used, that proxy's
|
||||||
|
public statement of acceptance of a version permanently authorizes you
|
||||||
|
to choose that version for the Program.
|
||||||
|
|
||||||
|
Later license versions may give you additional or different
|
||||||
|
permissions. However, no additional obligations are imposed on any
|
||||||
|
author or copyright holder as a result of your choosing to follow a
|
||||||
|
later version.
|
||||||
|
|
||||||
|
15. Disclaimer of Warranty.
|
||||||
|
|
||||||
|
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||||
|
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||||
|
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||||
|
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||||
|
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||||
|
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||||
|
|
||||||
|
16. Limitation of Liability.
|
||||||
|
|
||||||
|
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||||
|
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||||
|
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||||
|
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||||
|
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||||
|
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||||
|
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||||
|
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||||
|
SUCH DAMAGES.
|
||||||
|
|
||||||
|
17. Interpretation of Sections 15 and 16.
|
||||||
|
|
||||||
|
If the disclaimer of warranty and limitation of liability provided
|
||||||
|
above cannot be given local legal effect according to their terms,
|
||||||
|
reviewing courts shall apply local law that most closely approximates
|
||||||
|
an absolute waiver of all civil liability in connection with the
|
||||||
|
Program, unless a warranty or assumption of liability accompanies a
|
||||||
|
copy of the Program in return for a fee.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
How to Apply These Terms to Your New Programs
|
||||||
|
|
||||||
|
If you develop a new program, and you want it to be of the greatest
|
||||||
|
possible use to the public, the best way to achieve this is to make it
|
||||||
|
free software which everyone can redistribute and change under these terms.
|
||||||
|
|
||||||
|
To do so, attach the following notices to the program. It is safest
|
||||||
|
to attach them to the start of each source file to most effectively
|
||||||
|
state the exclusion of warranty; and each file should have at least
|
||||||
|
the "copyright" line and a pointer to where the full notice is found.
|
||||||
|
|
||||||
|
<one line to give the program's name and a brief idea of what it does.>
|
||||||
|
Copyright (C) <year> <name of author>
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
Also add information on how to contact you by electronic and paper mail.
|
||||||
|
|
||||||
|
If your software can interact with users remotely through a computer
|
||||||
|
network, you should also make sure that it provides a way for users to
|
||||||
|
get its source. For example, if your program is a web application, its
|
||||||
|
interface could display a "Source" link that leads users to an archive
|
||||||
|
of the code. There are many ways you could offer source, and different
|
||||||
|
solutions will be better for different programs; see section 13 for the
|
||||||
|
specific requirements.
|
||||||
|
|
||||||
|
You should also get your employer (if you work as a programmer) or school,
|
||||||
|
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||||
|
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||||
|
<https://www.gnu.org/licenses/>.
|
||||||
@@ -627,7 +627,15 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
|||||||
|
|
||||||
// Set HTTP config fields for embedded IDP
|
// Set HTTP config fields for embedded IDP
|
||||||
httpConfig.AuthIssuer = mgmt.Auth.Issuer
|
httpConfig.AuthIssuer = mgmt.Auth.Issuer
|
||||||
|
httpConfig.AuthAudience = "netbird-dashboard"
|
||||||
|
httpConfig.AuthClientID = httpConfig.AuthAudience
|
||||||
|
httpConfig.CLIAuthAudience = "netbird-cli"
|
||||||
|
httpConfig.AuthUserIDClaim = "sub"
|
||||||
|
httpConfig.AuthKeysLocation = mgmt.Auth.Issuer + "/keys"
|
||||||
|
httpConfig.OIDCConfigEndpoint = mgmt.Auth.Issuer + "/.well-known/openid-configuration"
|
||||||
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
|
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
|
||||||
|
callbackURL := strings.TrimSuffix(httpConfig.AuthIssuer, "/oauth2")
|
||||||
|
httpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
|
||||||
|
|
||||||
return &nbconfig.Config{
|
return &nbconfig.Config{
|
||||||
Stuns: stuns,
|
Stuns: stuns,
|
||||||
|
|||||||
@@ -62,6 +62,8 @@ Configuration is loaded from a YAML file specified with --config.`,
|
|||||||
func init() {
|
func init() {
|
||||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
||||||
_ = rootCmd.MarkPersistentFlagRequired("config")
|
_ = rootCmd.MarkPersistentFlagRequired("config")
|
||||||
|
|
||||||
|
rootCmd.AddCommand(newTokenCommands())
|
||||||
}
|
}
|
||||||
|
|
||||||
func Execute() error {
|
func Execute() error {
|
||||||
|
|||||||
60
combined/cmd/token.go
Normal file
60
combined/cmd/token.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTokenCommands creates the token command tree with combined-specific store opener.
|
||||||
|
func newTokenCommands() *cobra.Command {
|
||||||
|
return tokencmd.NewCommands(withTokenStore)
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
|
||||||
|
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||||
|
if err := util.InitLog("error", "console"); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||||
|
|
||||||
|
cfg, err := LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||||
|
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||||
|
case "postgres":
|
||||||
|
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||||
|
case "mysql":
|
||||||
|
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
datadir := cfg.Management.DataDir
|
||||||
|
engine := types.Engine(cfg.Management.Store.Engine)
|
||||||
|
|
||||||
|
s, err := store.NewStore(ctx, engine, datadir, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create store: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := s.Close(ctx); err != nil {
|
||||||
|
log.Debugf("close store: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, s)
|
||||||
|
}
|
||||||
2
go.mod
2
go.mod
@@ -42,6 +42,7 @@ require (
|
|||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.15.0
|
||||||
github.com/coder/websocket v1.8.13
|
github.com/coder/websocket v1.8.13
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
|
github.com/coreos/go-oidc/v3 v3.14.1
|
||||||
github.com/creack/pty v1.1.24
|
github.com/creack/pty v1.1.24
|
||||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||||
github.com/dexidp/dex/api/v2 v2.4.0
|
github.com/dexidp/dex/api/v2 v2.4.0
|
||||||
@@ -167,7 +168,6 @@ require (
|
|||||||
github.com/containerd/containerd v1.7.29 // indirect
|
github.com/containerd/containerd v1.7.29 // indirect
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/coreos/go-oidc/v3 v3.14.1 // indirect
|
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
|
|||||||
@@ -166,6 +166,65 @@ read_proxy_docker_network() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
read_enable_proxy() {
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "Do you want to enable the NetBird Proxy service?" > /dev/stderr
|
||||||
|
echo "The proxy exposes internal NetBird network resources to the internet." > /dev/stderr
|
||||||
|
echo -n "Enable proxy? [y/N]: " > /dev/stderr
|
||||||
|
read -r CHOICE < /dev/tty
|
||||||
|
|
||||||
|
if [[ "$CHOICE" =~ ^[Yy]$ ]]; then
|
||||||
|
echo "true"
|
||||||
|
else
|
||||||
|
echo "false"
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
read_proxy_domain() {
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "WARNING: The proxy domain MUST NOT be a subdomain of the NetBird management" > /dev/stderr
|
||||||
|
echo "domain ($NETBIRD_DOMAIN). Using a subdomain will cause TLS certificate conflicts." > /dev/stderr
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo -n "Enter the domain for the NetBird Proxy (e.g. proxy.my-domain.com): " > /dev/stderr
|
||||||
|
read -r READ_PROXY_DOMAIN < /dev/tty
|
||||||
|
|
||||||
|
if [[ -z "$READ_PROXY_DOMAIN" ]]; then
|
||||||
|
echo "The proxy domain cannot be empty." > /dev/stderr
|
||||||
|
read_proxy_domain
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$READ_PROXY_DOMAIN" == "$NETBIRD_DOMAIN" ]]; then
|
||||||
|
echo "The proxy domain cannot be the same as the management domain ($NETBIRD_DOMAIN)." > /dev/stderr
|
||||||
|
read_proxy_domain
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$READ_PROXY_DOMAIN" == *".${NETBIRD_DOMAIN}" ]]; then
|
||||||
|
echo "The proxy domain cannot be a subdomain of the management domain ($NETBIRD_DOMAIN)." > /dev/stderr
|
||||||
|
read_proxy_domain
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "$READ_PROXY_DOMAIN"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
read_traefik_acme_email() {
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr
|
||||||
|
echo -n "Email address: " > /dev/stderr
|
||||||
|
read -r EMAIL < /dev/tty
|
||||||
|
if [[ -z "$EMAIL" ]]; then
|
||||||
|
echo "Email is required for Let's Encrypt." > /dev/stderr
|
||||||
|
read_traefik_acme_email
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
echo "$EMAIL"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
get_bind_address() {
|
get_bind_address() {
|
||||||
if [[ "$BIND_LOCALHOST_ONLY" == "true" ]]; then
|
if [[ "$BIND_LOCALHOST_ONLY" == "true" ]]; then
|
||||||
echo "127.0.0.1"
|
echo "127.0.0.1"
|
||||||
@@ -248,16 +307,23 @@ initialize_default_values() {
|
|||||||
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
||||||
# Combined server replaces separate signal, relay, and management containers
|
# Combined server replaces separate signal, relay, and management containers
|
||||||
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:latest"
|
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:latest"
|
||||||
|
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:latest"
|
||||||
|
|
||||||
# Reverse proxy configuration
|
# Reverse proxy configuration
|
||||||
REVERSE_PROXY_TYPE="0"
|
REVERSE_PROXY_TYPE="0"
|
||||||
TRAEFIK_EXTERNAL_NETWORK=""
|
TRAEFIK_EXTERNAL_NETWORK=""
|
||||||
TRAEFIK_ENTRYPOINT="websecure"
|
TRAEFIK_ENTRYPOINT="websecure"
|
||||||
TRAEFIK_CERTRESOLVER=""
|
TRAEFIK_CERTRESOLVER=""
|
||||||
|
TRAEFIK_ACME_EMAIL=""
|
||||||
DASHBOARD_HOST_PORT="8080"
|
DASHBOARD_HOST_PORT="8080"
|
||||||
MANAGEMENT_HOST_PORT="8081" # Combined server port (management + signal + relay)
|
MANAGEMENT_HOST_PORT="8081" # Combined server port (management + signal + relay)
|
||||||
BIND_LOCALHOST_ONLY="true"
|
BIND_LOCALHOST_ONLY="true"
|
||||||
EXTERNAL_PROXY_NETWORK=""
|
EXTERNAL_PROXY_NETWORK=""
|
||||||
|
|
||||||
|
# NetBird Proxy configuration
|
||||||
|
ENABLE_PROXY="false"
|
||||||
|
PROXY_DOMAIN=""
|
||||||
|
PROXY_TOKEN=""
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,7 +346,16 @@ configure_reverse_proxy() {
|
|||||||
# Prompt for reverse proxy type
|
# Prompt for reverse proxy type
|
||||||
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
|
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
|
||||||
|
|
||||||
# Handle Traefik-specific prompts (only for external Traefik)
|
# Handle built-in Traefik prompts (option 0)
|
||||||
|
if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then
|
||||||
|
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
|
||||||
|
ENABLE_PROXY=$(read_enable_proxy)
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
PROXY_DOMAIN=$(read_proxy_domain)
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Handle external Traefik-specific prompts (option 1)
|
||||||
if [[ "$REVERSE_PROXY_TYPE" == "1" ]]; then
|
if [[ "$REVERSE_PROXY_TYPE" == "1" ]]; then
|
||||||
TRAEFIK_EXTERNAL_NETWORK=$(read_traefik_network)
|
TRAEFIK_EXTERNAL_NETWORK=$(read_traefik_network)
|
||||||
TRAEFIK_ENTRYPOINT=$(read_traefik_entrypoint)
|
TRAEFIK_ENTRYPOINT=$(read_traefik_entrypoint)
|
||||||
@@ -307,7 +382,7 @@ check_existing_installation() {
|
|||||||
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||||
echo "You can use the following commands:"
|
echo "You can use the following commands:"
|
||||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||||
echo " rm -f docker-compose.yml dashboard.env config.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
||||||
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
@@ -321,6 +396,12 @@ generate_configuration_files() {
|
|||||||
case "$REVERSE_PROXY_TYPE" in
|
case "$REVERSE_PROXY_TYPE" in
|
||||||
0)
|
0)
|
||||||
render_docker_compose_traefik_builtin > docker-compose.yml
|
render_docker_compose_traefik_builtin > docker-compose.yml
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
# Create placeholder proxy.env so docker-compose can validate
|
||||||
|
# This will be overwritten with the actual token after netbird-server starts
|
||||||
|
echo "# Placeholder - will be updated with token after netbird-server starts" > proxy.env
|
||||||
|
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
|
||||||
|
fi
|
||||||
;;
|
;;
|
||||||
1)
|
1)
|
||||||
render_docker_compose_traefik > docker-compose.yml
|
render_docker_compose_traefik > docker-compose.yml
|
||||||
@@ -357,12 +438,45 @@ start_services_and_show_instructions() {
|
|||||||
# For NPM, start containers first (NPM needs services running to create proxy)
|
# For NPM, start containers first (NPM needs services running to create proxy)
|
||||||
# For other external proxies, show instructions first and wait for user confirmation
|
# For other external proxies, show instructions first and wait for user confirmation
|
||||||
if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then
|
if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then
|
||||||
# Built-in Traefik - handles everything automatically (TLS via Let's Encrypt)
|
# Built-in Traefik - two-phase startup if proxy is enabled
|
||||||
echo -e "$MSG_STARTING_SERVICES"
|
echo -e "$MSG_STARTING_SERVICES"
|
||||||
$DOCKER_COMPOSE_COMMAND up -d
|
|
||||||
|
|
||||||
sleep 3
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
wait_management_proxy traefik
|
# Phase 1: Start core services (without proxy)
|
||||||
|
echo "Starting core services..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d traefik dashboard netbird-server
|
||||||
|
|
||||||
|
sleep 3
|
||||||
|
wait_management_proxy traefik
|
||||||
|
|
||||||
|
# Phase 2: Create proxy token and start proxy
|
||||||
|
echo ""
|
||||||
|
echo "Creating proxy access token..."
|
||||||
|
# Use docker exec with bash to run the token command directly
|
||||||
|
PROXY_TOKEN=$($DOCKER_COMPOSE_COMMAND exec -T netbird-server \
|
||||||
|
/go/bin/netbird-server token create --name "default-proxy" --config /etc/netbird/config.yaml 2>/dev/null | grep "^Token:" | awk '{print $2}')
|
||||||
|
|
||||||
|
if [[ -z "$PROXY_TOKEN" ]]; then
|
||||||
|
echo "ERROR: Failed to create proxy token. Check netbird-server logs." > /dev/stderr
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Proxy token created successfully."
|
||||||
|
|
||||||
|
# Generate proxy.env with the token
|
||||||
|
render_proxy_env > proxy.env
|
||||||
|
|
||||||
|
# Start proxy service
|
||||||
|
echo "Starting proxy service..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d proxy
|
||||||
|
else
|
||||||
|
# No proxy - start all services at once
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d
|
||||||
|
|
||||||
|
sleep 3
|
||||||
|
wait_management_proxy traefik
|
||||||
|
fi
|
||||||
|
|
||||||
echo -e "$MSG_DONE"
|
echo -e "$MSG_DONE"
|
||||||
print_post_setup_instructions
|
print_post_setup_instructions
|
||||||
@@ -434,6 +548,45 @@ init_environment() {
|
|||||||
############################################
|
############################################
|
||||||
|
|
||||||
render_docker_compose_traefik_builtin() {
|
render_docker_compose_traefik_builtin() {
|
||||||
|
# Generate proxy service section if enabled
|
||||||
|
local proxy_service=""
|
||||||
|
local proxy_volumes=""
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
proxy_service="
|
||||||
|
# NetBird Proxy - exposes internal resources to the internet
|
||||||
|
proxy:
|
||||||
|
image: $NETBIRD_PROXY_IMAGE
|
||||||
|
container_name: netbird-proxy
|
||||||
|
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
|
||||||
|
extra_hosts:
|
||||||
|
- \"$NETBIRD_DOMAIN:172.30.0.10\"
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
depends_on:
|
||||||
|
- netbird-server
|
||||||
|
env_file:
|
||||||
|
- ./proxy.env
|
||||||
|
volumes:
|
||||||
|
- netbird_proxy_certs:/certs
|
||||||
|
labels:
|
||||||
|
# TCP passthrough for any unmatched domain (proxy handles its own TLS)
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.entrypoints=websecure
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.rule=HostSNI(\`*\`)
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.tls.passthrough=true
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.priority=1
|
||||||
|
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
|
||||||
|
logging:
|
||||||
|
driver: \"json-file\"
|
||||||
|
options:
|
||||||
|
max-size: \"500m\"
|
||||||
|
max-file: \"2\"
|
||||||
|
"
|
||||||
|
proxy_volumes="
|
||||||
|
netbird_proxy_certs:"
|
||||||
|
fi
|
||||||
|
|
||||||
cat <<EOF
|
cat <<EOF
|
||||||
services:
|
services:
|
||||||
# Traefik reverse proxy (automatic TLS via Let's Encrypt)
|
# Traefik reverse proxy (automatic TLS via Let's Encrypt)
|
||||||
@@ -441,18 +594,35 @@ services:
|
|||||||
image: traefik:v3.6
|
image: traefik:v3.6
|
||||||
container_name: netbird-traefik
|
container_name: netbird-traefik
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks: [netbird]
|
networks:
|
||||||
|
netbird:
|
||||||
|
ipv4_address: 172.30.0.10
|
||||||
command:
|
command:
|
||||||
|
# Logging
|
||||||
|
- "--log.level=INFO"
|
||||||
|
- "--accesslog=true"
|
||||||
|
# Docker provider
|
||||||
- "--providers.docker=true"
|
- "--providers.docker=true"
|
||||||
- "--providers.docker.exposedbydefault=false"
|
- "--providers.docker.exposedbydefault=false"
|
||||||
- "--providers.docker.network=netbird"
|
- "--providers.docker.network=netbird"
|
||||||
|
# Entrypoints
|
||||||
- "--entrypoints.web.address=:80"
|
- "--entrypoints.web.address=:80"
|
||||||
- "--entrypoints.websecure.address=:443"
|
- "--entrypoints.websecure.address=:443"
|
||||||
|
- "--entrypoints.websecure.allowACMEByPass=true"
|
||||||
|
# Disable timeouts for long-lived gRPC streams
|
||||||
- "--entrypoints.websecure.transport.respondingTimeouts.readTimeout=0"
|
- "--entrypoints.websecure.transport.respondingTimeouts.readTimeout=0"
|
||||||
|
- "--entrypoints.websecure.transport.respondingTimeouts.writeTimeout=0"
|
||||||
|
- "--entrypoints.websecure.transport.respondingTimeouts.idleTimeout=0"
|
||||||
|
# HTTP to HTTPS redirect
|
||||||
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
||||||
- "--entrypoints.web.http.redirections.entrypoint.scheme=https"
|
- "--entrypoints.web.http.redirections.entrypoint.scheme=https"
|
||||||
- "--certificatesresolvers.letsencrypt.acme.tlschallenge=true"
|
# Let's Encrypt ACME
|
||||||
|
- "--certificatesresolvers.letsencrypt.acme.email=$TRAEFIK_ACME_EMAIL"
|
||||||
- "--certificatesresolvers.letsencrypt.acme.storage=/letsencrypt/acme.json"
|
- "--certificatesresolvers.letsencrypt.acme.storage=/letsencrypt/acme.json"
|
||||||
|
- "--certificatesresolvers.letsencrypt.acme.tlschallenge=true"
|
||||||
|
# gRPC transport settings
|
||||||
|
- "--serverstransport.forwardingtimeouts.responseheadertimeout=0s"
|
||||||
|
- "--serverstransport.forwardingtimeouts.idleconntimeout=0s"
|
||||||
ports:
|
ports:
|
||||||
- '443:443'
|
- '443:443'
|
||||||
- '80:80'
|
- '80:80'
|
||||||
@@ -479,8 +649,9 @@ services:
|
|||||||
- traefik.http.routers.netbird-dashboard.entrypoints=websecure
|
- traefik.http.routers.netbird-dashboard.entrypoints=websecure
|
||||||
- traefik.http.routers.netbird-dashboard.tls=true
|
- traefik.http.routers.netbird-dashboard.tls=true
|
||||||
- traefik.http.routers.netbird-dashboard.tls.certresolver=letsencrypt
|
- traefik.http.routers.netbird-dashboard.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-dashboard.service=dashboard
|
||||||
- traefik.http.routers.netbird-dashboard.priority=1
|
- traefik.http.routers.netbird-dashboard.priority=1
|
||||||
- traefik.http.services.netbird-dashboard.loadbalancer.server.port=80
|
- traefik.http.services.dashboard.loadbalancer.server.port=80
|
||||||
logging:
|
logging:
|
||||||
driver: "json-file"
|
driver: "json-file"
|
||||||
options:
|
options:
|
||||||
@@ -507,12 +678,14 @@ services:
|
|||||||
- traefik.http.routers.netbird-grpc.tls=true
|
- traefik.http.routers.netbird-grpc.tls=true
|
||||||
- traefik.http.routers.netbird-grpc.tls.certresolver=letsencrypt
|
- traefik.http.routers.netbird-grpc.tls.certresolver=letsencrypt
|
||||||
- traefik.http.routers.netbird-grpc.service=netbird-server-h2c
|
- traefik.http.routers.netbird-grpc.service=netbird-server-h2c
|
||||||
|
- traefik.http.routers.netbird-grpc.priority=100
|
||||||
# Backend router (relay, WebSocket, API, OAuth2)
|
# Backend router (relay, WebSocket, API, OAuth2)
|
||||||
- traefik.http.routers.netbird-backend.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/relay\`) || PathPrefix(\`/ws-proxy/\`) || PathPrefix(\`/api\`) || PathPrefix(\`/oauth2\`))
|
- traefik.http.routers.netbird-backend.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/relay\`) || PathPrefix(\`/ws-proxy/\`) || PathPrefix(\`/api\`) || PathPrefix(\`/oauth2\`))
|
||||||
- traefik.http.routers.netbird-backend.entrypoints=websecure
|
- traefik.http.routers.netbird-backend.entrypoints=websecure
|
||||||
- traefik.http.routers.netbird-backend.tls=true
|
- traefik.http.routers.netbird-backend.tls=true
|
||||||
- traefik.http.routers.netbird-backend.tls.certresolver=letsencrypt
|
- traefik.http.routers.netbird-backend.tls.certresolver=letsencrypt
|
||||||
- traefik.http.routers.netbird-backend.service=netbird-server
|
- traefik.http.routers.netbird-backend.service=netbird-server
|
||||||
|
- traefik.http.routers.netbird-backend.priority=100
|
||||||
# Services
|
# Services
|
||||||
- traefik.http.services.netbird-server.loadbalancer.server.port=80
|
- traefik.http.services.netbird-server.loadbalancer.server.port=80
|
||||||
- traefik.http.services.netbird-server-h2c.loadbalancer.server.port=80
|
- traefik.http.services.netbird-server-h2c.loadbalancer.server.port=80
|
||||||
@@ -522,13 +695,18 @@ services:
|
|||||||
options:
|
options:
|
||||||
max-size: "500m"
|
max-size: "500m"
|
||||||
max-file: "2"
|
max-file: "2"
|
||||||
|
${proxy_service}
|
||||||
volumes:
|
volumes:
|
||||||
netbird_data:
|
netbird_data:
|
||||||
netbird_traefik_letsencrypt:
|
netbird_traefik_letsencrypt:${proxy_volumes}
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
netbird:
|
netbird:
|
||||||
|
driver: bridge
|
||||||
|
ipam:
|
||||||
|
config:
|
||||||
|
- subnet: 172.30.0.0/24
|
||||||
|
gateway: 172.30.0.1
|
||||||
EOF
|
EOF
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@@ -589,6 +767,28 @@ EOF
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
render_proxy_env() {
|
||||||
|
cat <<EOF
|
||||||
|
# NetBird Proxy Configuration
|
||||||
|
NB_PROXY_DEBUG_LOGS=false
|
||||||
|
# Use internal Docker network to connect to management (avoids hairpin NAT issues)
|
||||||
|
NB_PROXY_MANAGEMENT_ADDRESS=http://netbird-server:80
|
||||||
|
# Allow insecure gRPC connection to management (required for internal Docker network)
|
||||||
|
NB_PROXY_ALLOW_INSECURE=true
|
||||||
|
# Public URL where this proxy is reachable (used for cluster registration)
|
||||||
|
NB_PROXY_DOMAIN=$PROXY_DOMAIN
|
||||||
|
NB_PROXY_ADDRESS=:8443
|
||||||
|
NB_PROXY_TOKEN=$PROXY_TOKEN
|
||||||
|
NB_PROXY_CERTIFICATE_DIRECTORY=/certs
|
||||||
|
NB_PROXY_ACME_CERTIFICATES=true
|
||||||
|
NB_PROXY_ACME_CHALLENGE_TYPE=tls-alpn-01
|
||||||
|
NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
|
||||||
|
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
||||||
|
NB_PROXY_OIDC_SCOPES=openid,profile,email
|
||||||
|
NB_PROXY_FORWARDED_PROTO=https
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
render_docker_compose_traefik() {
|
render_docker_compose_traefik() {
|
||||||
local network_name="${TRAEFIK_EXTERNAL_NETWORK:-netbird}"
|
local network_name="${TRAEFIK_EXTERNAL_NETWORK:-netbird}"
|
||||||
@@ -939,11 +1139,29 @@ EOF
|
|||||||
############################################
|
############################################
|
||||||
|
|
||||||
print_builtin_traefik_instructions() {
|
print_builtin_traefik_instructions() {
|
||||||
|
echo ""
|
||||||
|
echo "$MSG_SEPARATOR"
|
||||||
|
echo " NETBIRD SETUP COMPLETE"
|
||||||
|
echo "$MSG_SEPARATOR"
|
||||||
|
echo ""
|
||||||
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||||
echo "Follow the onboarding steps to set up your NetBird instance."
|
echo "Follow the onboarding steps to set up your NetBird instance."
|
||||||
echo ""
|
echo ""
|
||||||
echo "Traefik is handling TLS certificates automatically via Let's Encrypt."
|
echo "Traefik is handling TLS certificates automatically via Let's Encrypt."
|
||||||
echo "If you see certificate warnings, wait a moment for certificate issuance to complete."
|
echo "If you see certificate warnings, wait a moment for certificate issuance to complete."
|
||||||
|
echo ""
|
||||||
|
echo "Open ports:"
|
||||||
|
echo " - 443/tcp (HTTPS - all NetBird services)"
|
||||||
|
echo " - 80/tcp (HTTP - redirects to HTTPS)"
|
||||||
|
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "NetBird Proxy:"
|
||||||
|
echo " The proxy service is enabled and running."
|
||||||
|
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
|
||||||
|
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
|
||||||
|
echo " Point your proxy domains (CNAMEs) to this server's IP address."
|
||||||
|
fi
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
17
management/Dockerfile.multistage
Normal file
17
management/Dockerfile.multistage
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
FROM golang:1.25-bookworm AS builder
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apt-get update && apt-get install -y gcc libc6-dev && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o netbird-mgmt ./management
|
||||||
|
|
||||||
|
FROM ubuntu:24.04
|
||||||
|
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||||
|
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
|
||||||
|
CMD ["--log-file", "console"]
|
||||||
|
COPY --from=builder /app/netbird-mgmt /go/bin/netbird-mgmt
|
||||||
@@ -19,6 +19,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
"github.com/netbirdio/netbird/management/internals/server"
|
"github.com/netbirdio/netbird/management/internals/server"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -213,11 +215,14 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
|||||||
// Set HttpConfig values from EmbeddedIdP
|
// Set HttpConfig values from EmbeddedIdP
|
||||||
cfg.HttpConfig.AuthIssuer = issuer
|
cfg.HttpConfig.AuthIssuer = issuer
|
||||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||||
|
cfg.HttpConfig.AuthClientID = cfg.HttpConfig.AuthAudience
|
||||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||||
|
callbackURL := strings.TrimSuffix(cfg.HttpConfig.AuthIssuer, "/oauth2")
|
||||||
|
cfg.HttpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,4 +80,8 @@ func init() {
|
|||||||
migrationCmd.AddCommand(upCmd)
|
migrationCmd.AddCommand(upCmd)
|
||||||
|
|
||||||
rootCmd.AddCommand(migrationCmd)
|
rootCmd.AddCommand(migrationCmd)
|
||||||
|
|
||||||
|
tc := newTokenCommands()
|
||||||
|
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||||
|
rootCmd.AddCommand(tc)
|
||||||
}
|
}
|
||||||
|
|||||||
55
management/cmd/token.go
Normal file
55
management/cmd/token.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var tokenDatadir string
|
||||||
|
|
||||||
|
// newTokenCommands creates the token command tree with management-specific store opener.
|
||||||
|
func newTokenCommands() *cobra.Command {
|
||||||
|
cmd := tokencmd.NewCommands(withTokenStore)
|
||||||
|
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||||
|
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||||
|
if err := util.InitLog("error", "console"); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||||
|
|
||||||
|
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
datadir := config.Datadir
|
||||||
|
if tokenDatadir != "" {
|
||||||
|
datadir = tokenDatadir
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create store: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := s.Close(ctx); err != nil {
|
||||||
|
log.Debugf("close store: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, s)
|
||||||
|
}
|
||||||
185
management/cmd/token/token.go
Normal file
185
management/cmd/token/token.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
// Package tokencmd provides reusable cobra commands for managing proxy access tokens.
|
||||||
|
// Both the management and combined binaries use these commands, each providing
|
||||||
|
// their own StoreOpener to handle config loading and store initialization.
|
||||||
|
package tokencmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
"text/tabwriter"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StoreOpener initializes a store from the command context and calls fn.
|
||||||
|
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
|
||||||
|
|
||||||
|
// NewCommands creates the token command tree with the given store opener.
|
||||||
|
// Returns the parent "token" command with create, list, and revoke subcommands.
|
||||||
|
func NewCommands(opener StoreOpener) *cobra.Command {
|
||||||
|
var (
|
||||||
|
tokenName string
|
||||||
|
tokenExpireIn string
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenCmd := &cobra.Command{
|
||||||
|
Use: "token",
|
||||||
|
Short: "Manage proxy access tokens",
|
||||||
|
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||||
|
}
|
||||||
|
|
||||||
|
createCmd := &cobra.Command{
|
||||||
|
Use: "create",
|
||||||
|
Short: "Create a new proxy access token",
|
||||||
|
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
return runCreate(ctx, s, cmd.OutOrStdout(), tokenName, tokenExpireIn)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
createCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||||
|
createCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||||
|
if err := createCmd.MarkFlagRequired("name"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listCmd := &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List all proxy access tokens",
|
||||||
|
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
return runList(ctx, s, cmd.OutOrStdout())
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
revokeCmd := &cobra.Command{
|
||||||
|
Use: "revoke [token-id]",
|
||||||
|
Short: "Revoke a proxy access token",
|
||||||
|
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
return runRevoke(ctx, s, cmd.OutOrStdout(), args[0])
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenCmd.AddCommand(createCmd, listCmd, revokeCmd)
|
||||||
|
return tokenCmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func runCreate(ctx context.Context, s store.Store, w io.Writer, name string, expireIn string) error {
|
||||||
|
expiresIn, err := ParseDuration(expireIn)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse expiration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated, err := types.CreateNewProxyAccessToken(name, expiresIn, nil, "CLI")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||||
|
return fmt.Errorf("save token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintln(w, "Token created successfully!")
|
||||||
|
_, _ = fmt.Fprintf(w, "Token: %s\n", generated.PlainToken)
|
||||||
|
_, _ = fmt.Fprintln(w)
|
||||||
|
_, _ = fmt.Fprintln(w, "IMPORTANT: Save this token now. It will not be shown again.")
|
||||||
|
_, _ = fmt.Fprintf(w, "Token ID: %s\n", generated.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runList(ctx context.Context, s store.Store, out io.Writer) error {
|
||||||
|
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
_, _ = fmt.Fprintln(out, "No proxy access tokens found.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
|
||||||
|
_, _ = fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||||
|
_, _ = fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||||
|
|
||||||
|
for _, t := range tokens {
|
||||||
|
expires := "never"
|
||||||
|
if t.ExpiresAt != nil {
|
||||||
|
expires = t.ExpiresAt.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
|
||||||
|
lastUsed := "never"
|
||||||
|
if t.LastUsed != nil {
|
||||||
|
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||||
|
}
|
||||||
|
|
||||||
|
revoked := "no"
|
||||||
|
if t.Revoked {
|
||||||
|
revoked = "yes"
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||||
|
t.ID,
|
||||||
|
t.Name,
|
||||||
|
t.CreatedAt.Format("2006-01-02"),
|
||||||
|
expires,
|
||||||
|
lastUsed,
|
||||||
|
revoked,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Flush()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runRevoke(ctx context.Context, s store.Store, w io.Writer, tokenID string) error {
|
||||||
|
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||||
|
return fmt.Errorf("revoke token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(w, "Token %s revoked successfully.\n", tokenID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||||
|
// An empty string returns zero duration (no expiration).
|
||||||
|
func ParseDuration(s string) (time.Duration, error) {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s[len(s)-1] == 'd' {
|
||||||
|
d, err := strconv.Atoi(s[:len(s)-1])
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||||
|
}
|
||||||
|
return time.Duration(d) * 24 * time.Hour, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
101
management/cmd/token/token_test.go
Normal file
101
management/cmd/token/token_test.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package tokencmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseDuration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected time.Duration
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string returns zero",
|
||||||
|
input: "",
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "days suffix",
|
||||||
|
input: "30d",
|
||||||
|
expected: 30 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one day",
|
||||||
|
input: "1d",
|
||||||
|
expected: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "365 days",
|
||||||
|
input: "365d",
|
||||||
|
expected: 365 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hours via Go duration",
|
||||||
|
input: "24h",
|
||||||
|
expected: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "minutes via Go duration",
|
||||||
|
input: "30m",
|
||||||
|
expected: 30 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex Go duration",
|
||||||
|
input: "1h30m",
|
||||||
|
expected: 90 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid day format",
|
||||||
|
input: "abcd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative days",
|
||||||
|
input: "-1d",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero days",
|
||||||
|
input: "0d",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-numeric days",
|
||||||
|
input: "xyzd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative Go duration",
|
||||||
|
input: "-24h",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero Go duration",
|
||||||
|
input: "0s",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid Go duration",
|
||||||
|
input: "notaduration",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := ParseDuration(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -174,6 +174,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
semaphore := make(chan struct{}, 10)
|
semaphore := make(chan struct{}, 10)
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -326,6 +327,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -441,6 +443,8 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
|
|
||||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
@@ -847,6 +851,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||||
} else {
|
} else {
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
@@ -32,6 +33,7 @@ type Manager interface {
|
|||||||
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
||||||
SetAccountManager(accountManager account.Manager)
|
SetAccountManager(accountManager account.Manager)
|
||||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||||
|
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
@@ -182,3 +184,36 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
// GetPeerID resolves a peer's internal ID from its public key.
// This is a read-only lookup; no row lock is taken (LockingStrengthNone).
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
	return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
}
|
||||||
|
|
||||||
|
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||||
|
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||||
|
if err == nil && existingPeerID != "" {
|
||||||
|
// Peer already exists
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := fmt.Sprintf("proxy-%s", xid.New().String())
|
||||||
|
peer := &peer.Peer{
|
||||||
|
Ephemeral: true,
|
||||||
|
ProxyMeta: peer.ProxyMeta{
|
||||||
|
Cluster: cluster,
|
||||||
|
Embedded: true,
|
||||||
|
},
|
||||||
|
Name: name,
|
||||||
|
Key: peerKey,
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
Hostname: name,
|
||||||
|
GoOS: "proxy",
|
||||||
|
OS: "proxy",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create proxy peer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -162,3 +162,17 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer mocks base method.
// NOTE: gomock-style generated stub — keep in sync with the Manager interface
// rather than editing by hand.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
	ret0, _ := ret[0].(error)
	return ret0
}

// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccessLogEntry is the persisted (GORM) representation of a single
// reverse-proxy request log record.
type AccessLogEntry struct {
	ID        string    `gorm:"primaryKey"` // log ID supplied by the proxy
	AccountID string    `gorm:"index"`
	ServiceID string    `gorm:"index"`
	Timestamp time.Time `gorm:"index"` // time the request was handled
	// GeoLocation holds the client source IP and its resolved geo data,
	// stored in location_-prefixed columns.
	GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
	Method      string        `gorm:"index"` // HTTP method
	Host        string        `gorm:"index"` // request Host header
	Path        string        `gorm:"index"`
	Duration    time.Duration `gorm:"index"` // request handling duration
	StatusCode  int           `gorm:"index"` // HTTP response status code
	// Reason is a short human-readable failure description, empty on success.
	Reason         string
	UserId         string `gorm:"index"` // authenticated user, empty if anonymous
	AuthMethodUsed string `gorm:"index"`
}
|
||||||
|
|
||||||
|
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||||
|
func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
|
||||||
|
a.ID = serviceLog.GetLogId()
|
||||||
|
a.ServiceID = serviceLog.GetServiceId()
|
||||||
|
a.Timestamp = serviceLog.GetTimestamp().AsTime()
|
||||||
|
a.Method = serviceLog.GetMethod()
|
||||||
|
a.Host = serviceLog.GetHost()
|
||||||
|
a.Path = serviceLog.GetPath()
|
||||||
|
a.Duration = time.Duration(serviceLog.GetDurationMs()) * time.Millisecond
|
||||||
|
a.StatusCode = int(serviceLog.GetResponseCode())
|
||||||
|
a.UserId = serviceLog.GetUserId()
|
||||||
|
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
|
||||||
|
a.AccountID = serviceLog.GetAccountId()
|
||||||
|
|
||||||
|
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
||||||
|
if ip, err := netip.ParseAddr(sourceIP); err == nil {
|
||||||
|
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceLog.GetAuthSuccess() {
|
||||||
|
a.Reason = "Authentication failed"
|
||||||
|
} else if serviceLog.GetResponseCode() >= 400 {
|
||||||
|
a.Reason = "Request failed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToAPIResponse converts an AccessLogEntry to the API ProxyAccessLog type
|
||||||
|
func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
||||||
|
var sourceIP *string
|
||||||
|
if a.GeoLocation.ConnectionIP != nil {
|
||||||
|
ip := a.GeoLocation.ConnectionIP.String()
|
||||||
|
sourceIP = &ip
|
||||||
|
}
|
||||||
|
|
||||||
|
var reason *string
|
||||||
|
if a.Reason != "" {
|
||||||
|
reason = &a.Reason
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID *string
|
||||||
|
if a.UserId != "" {
|
||||||
|
userID = &a.UserId
|
||||||
|
}
|
||||||
|
|
||||||
|
var authMethod *string
|
||||||
|
if a.AuthMethodUsed != "" {
|
||||||
|
authMethod = &a.AuthMethodUsed
|
||||||
|
}
|
||||||
|
|
||||||
|
var countryCode *string
|
||||||
|
if a.GeoLocation.CountryCode != "" {
|
||||||
|
countryCode = &a.GeoLocation.CountryCode
|
||||||
|
}
|
||||||
|
|
||||||
|
var cityName *string
|
||||||
|
if a.GeoLocation.CityName != "" {
|
||||||
|
cityName = &a.GeoLocation.CityName
|
||||||
|
}
|
||||||
|
|
||||||
|
return &api.ProxyAccessLog{
|
||||||
|
Id: a.ID,
|
||||||
|
ServiceId: a.ServiceID,
|
||||||
|
Timestamp: a.Timestamp,
|
||||||
|
Method: a.Method,
|
||||||
|
Host: a.Host,
|
||||||
|
Path: a.Path,
|
||||||
|
DurationMs: int(a.Duration.Milliseconds()),
|
||||||
|
StatusCode: a.StatusCode,
|
||||||
|
SourceIp: sourceIP,
|
||||||
|
Reason: reason,
|
||||||
|
UserId: userID,
|
||||||
|
AuthMethodUsed: authMethod,
|
||||||
|
CountryCode: countryCode,
|
||||||
|
CityName: cityName,
|
||||||
|
}
|
||||||
|
}
|
||||||
109
management/internals/modules/reverseproxy/accesslogs/filter.go
Normal file
109
management/internals/modules/reverseproxy/accesslogs/filter.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// DefaultPageSize is the default number of records per page,
	// applied when the client omits or sends an invalid page_size.
	DefaultPageSize = 50
	// MaxPageSize is the maximum number of records allowed per page;
	// larger requested sizes are clamped to this value.
	MaxPageSize = 100
)
|
||||||
|
|
||||||
|
// AccessLogFilter holds pagination and filtering parameters for access logs.
// A nil filter field means "no constraint on this attribute".
type AccessLogFilter struct {
	// Page is the current page number (1-indexed)
	Page int
	// PageSize is the number of records per page
	PageSize int

	// Filtering parameters
	Search    *string // General search across log ID, host, path, source IP, and user fields
	SourceIP  *string // Filter by source IP address
	Host      *string // Filter by host header
	Path      *string // Filter by request path (supports LIKE pattern)
	UserID    *string // Filter by authenticated user ID
	UserEmail *string // Filter by user email (requires user lookup)
	UserName  *string // Filter by user name (requires user lookup)
	Method    *string // Filter by HTTP method
	Status    *string // Filter by status: "success" (2xx/3xx) or "failed" (1xx/4xx/5xx)
	StatusCode *int   // Filter by HTTP status code
	StartDate *time.Time // Filter by timestamp >= start_date
	EndDate   *time.Time // Filter by timestamp <= end_date
}
|
||||||
|
|
||||||
|
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
|
||||||
|
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||||
|
queryParams := r.URL.Query()
|
||||||
|
|
||||||
|
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
|
||||||
|
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
|
||||||
|
|
||||||
|
f.Search = parseOptionalString(queryParams.Get("search"))
|
||||||
|
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
|
||||||
|
f.Host = parseOptionalString(queryParams.Get("host"))
|
||||||
|
f.Path = parseOptionalString(queryParams.Get("path"))
|
||||||
|
f.UserID = parseOptionalString(queryParams.Get("user_id"))
|
||||||
|
f.UserEmail = parseOptionalString(queryParams.Get("user_email"))
|
||||||
|
f.UserName = parseOptionalString(queryParams.Get("user_name"))
|
||||||
|
f.Method = parseOptionalString(queryParams.Get("method"))
|
||||||
|
f.Status = parseOptionalString(queryParams.Get("status"))
|
||||||
|
f.StatusCode = parseOptionalInt(queryParams.Get("status_code"))
|
||||||
|
f.StartDate = parseOptionalRFC3339(queryParams.Get("start_date"))
|
||||||
|
f.EndDate = parseOptionalRFC3339(queryParams.Get("end_date"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePositiveInt parses a strictly positive integer from s, returning
// defaultValue for empty, non-numeric, zero, or negative input.
func parsePositiveInt(s string, defaultValue int) int {
	val, err := strconv.Atoi(s)
	if err != nil || val <= 0 {
		return defaultValue
	}
	return val
}
|
||||||
|
|
||||||
|
// parseOptionalString returns a pointer to s when it is non-empty;
// the empty string maps to nil ("filter not set").
func parseOptionalString(s string) *string {
	if len(s) == 0 {
		return nil
	}
	return &s
}
|
||||||
|
|
||||||
|
// parseOptionalInt parses an optional strictly positive integer from s.
// It returns nil for empty, non-numeric, zero, or negative input.
func parseOptionalInt(s string) *int {
	if s == "" {
		return nil
	}
	val, err := strconv.Atoi(s)
	if err != nil || val <= 0 {
		return nil
	}
	// val is a fresh per-call local, so returning its address directly is safe
	// (the previous extra copy was redundant).
	return &val
}
|
||||||
|
|
||||||
|
// parseOptionalRFC3339 parses an optional RFC3339 timestamp from s.
// Empty or unparseable input yields nil rather than an error.
func parseOptionalRFC3339(s string) *time.Time {
	if s == "" {
		return nil
	}
	ts, err := time.Parse(time.RFC3339, s)
	if err != nil {
		return nil
	}
	return &ts
}
|
||||||
|
|
||||||
|
// GetOffset calculates the database offset for pagination
|
||||||
|
func (f *AccessLogFilter) GetOffset() int {
|
||||||
|
return (f.Page - 1) * f.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLimit returns the page size for database queries.
// It simply mirrors PageSize so the store layer does not reach into the
// filter's fields directly.
func (f *AccessLogFilter) GetLimit() int {
	return f.PageSize
}
|
||||||
@@ -0,0 +1,371 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
expectedPage int
|
||||||
|
expectedPageSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default values when no params provided",
|
||||||
|
queryParams: map[string]string{},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid page and page_size",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "2",
|
||||||
|
"page_size": "25",
|
||||||
|
},
|
||||||
|
expectedPage: 2,
|
||||||
|
expectedPageSize: 25,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "page_size exceeds max, should cap at MaxPageSize",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "1",
|
||||||
|
"page_size": "200",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: MaxPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "invalid",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid page_size, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "2",
|
||||||
|
"page_size": "invalid",
|
||||||
|
},
|
||||||
|
expectedPage: 2,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "0",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "-1",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero page_size, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "1",
|
||||||
|
"page_size": "0",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
for key, value := range tt.queryParams {
|
||||||
|
q.Set(key, value)
|
||||||
|
}
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedPage, filter.Page, "Page mismatch")
|
||||||
|
assert.Equal(t, tt.expectedPageSize, filter.PageSize, "PageSize mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_GetOffset(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
page int
|
||||||
|
pageSize int
|
||||||
|
expectedOffset int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "first page",
|
||||||
|
page: 1,
|
||||||
|
pageSize: 50,
|
||||||
|
expectedOffset: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second page",
|
||||||
|
page: 2,
|
||||||
|
pageSize: 50,
|
||||||
|
expectedOffset: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "third page with page size 25",
|
||||||
|
page: 3,
|
||||||
|
pageSize: 25,
|
||||||
|
expectedOffset: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "page 10 with page size 10",
|
||||||
|
page: 10,
|
||||||
|
pageSize: 10,
|
||||||
|
expectedOffset: 90,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
filter := &AccessLogFilter{
|
||||||
|
Page: tt.page,
|
||||||
|
PageSize: tt.pageSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := filter.GetOffset()
|
||||||
|
assert.Equal(t, tt.expectedOffset, offset)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_GetLimit(t *testing.T) {
|
||||||
|
filter := &AccessLogFilter{
|
||||||
|
Page: 2,
|
||||||
|
PageSize: 25,
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := filter.GetLimit()
|
||||||
|
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest_FilterParams(t *testing.T) {
|
||||||
|
startDate := "2024-01-15T10:30:00Z"
|
||||||
|
endDate := "2024-01-16T15:45:00Z"
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Set("search", "test query")
|
||||||
|
q.Set("source_ip", "192.168.1.1")
|
||||||
|
q.Set("host", "example.com")
|
||||||
|
q.Set("path", "/api/users")
|
||||||
|
q.Set("user_id", "user123")
|
||||||
|
q.Set("user_email", "user@example.com")
|
||||||
|
q.Set("user_name", "John Doe")
|
||||||
|
q.Set("method", "GET")
|
||||||
|
q.Set("status", "success")
|
||||||
|
q.Set("status_code", "200")
|
||||||
|
q.Set("start_date", startDate)
|
||||||
|
q.Set("end_date", endDate)
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Search)
|
||||||
|
assert.Equal(t, "test query", *filter.Search)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.SourceIP)
|
||||||
|
assert.Equal(t, "192.168.1.1", *filter.SourceIP)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Host)
|
||||||
|
assert.Equal(t, "example.com", *filter.Host)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Path)
|
||||||
|
assert.Equal(t, "/api/users", *filter.Path)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.UserID)
|
||||||
|
assert.Equal(t, "user123", *filter.UserID)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.UserEmail)
|
||||||
|
assert.Equal(t, "user@example.com", *filter.UserEmail)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.UserName)
|
||||||
|
assert.Equal(t, "John Doe", *filter.UserName)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Method)
|
||||||
|
assert.Equal(t, "GET", *filter.Method)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Status)
|
||||||
|
assert.Equal(t, "success", *filter.Status)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.StatusCode)
|
||||||
|
assert.Equal(t, 200, *filter.StatusCode)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.StartDate)
|
||||||
|
expectedStart, _ := time.Parse(time.RFC3339, startDate)
|
||||||
|
assert.Equal(t, expectedStart, *filter.StartDate)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.EndDate)
|
||||||
|
expectedEnd, _ := time.Parse(time.RFC3339, endDate)
|
||||||
|
assert.Equal(t, expectedEnd, *filter.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest_EmptyFilters(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Nil(t, filter.Search)
|
||||||
|
assert.Nil(t, filter.SourceIP)
|
||||||
|
assert.Nil(t, filter.Host)
|
||||||
|
assert.Nil(t, filter.Path)
|
||||||
|
assert.Nil(t, filter.UserID)
|
||||||
|
assert.Nil(t, filter.UserEmail)
|
||||||
|
assert.Nil(t, filter.UserName)
|
||||||
|
assert.Nil(t, filter.Method)
|
||||||
|
assert.Nil(t, filter.Status)
|
||||||
|
assert.Nil(t, filter.StatusCode)
|
||||||
|
assert.Nil(t, filter.StartDate)
|
||||||
|
assert.Nil(t, filter.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest_InvalidFilters(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Set("status_code", "invalid")
|
||||||
|
q.Set("start_date", "not-a-date")
|
||||||
|
q.Set("end_date", "2024-99-99")
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Nil(t, filter.StatusCode, "invalid status_code should be nil")
|
||||||
|
assert.Nil(t, filter.StartDate, "invalid start_date should be nil")
|
||||||
|
assert.Nil(t, filter.EndDate, "invalid end_date should be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePositiveInt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
defaultValue int
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{"empty string", "", 10, 10},
|
||||||
|
{"valid positive int", "25", 10, 25},
|
||||||
|
{"zero", "0", 10, 10},
|
||||||
|
{"negative", "-5", 10, 10},
|
||||||
|
{"invalid string", "abc", 10, 10},
|
||||||
|
{"float", "3.14", 10, 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parsePositiveInt(tt.input, tt.defaultValue)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionalString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected *string
|
||||||
|
}{
|
||||||
|
{"empty string", "", nil},
|
||||||
|
{"valid string", "hello", strPtr("hello")},
|
||||||
|
{"whitespace", " ", strPtr(" ")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseOptionalString(tt.input)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, *tt.expected, *result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionalInt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected *int
|
||||||
|
}{
|
||||||
|
{"empty string", "", nil},
|
||||||
|
{"valid positive int", "42", intPtr(42)},
|
||||||
|
{"zero", "0", nil},
|
||||||
|
{"negative", "-10", nil},
|
||||||
|
{"invalid string", "abc", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseOptionalInt(tt.input)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, *tt.expected, *result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionalRFC3339(t *testing.T) {
|
||||||
|
validDate := "2024-01-15T10:30:00Z"
|
||||||
|
expectedTime, _ := time.Parse(time.RFC3339, validDate)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected *time.Time
|
||||||
|
}{
|
||||||
|
{"empty string", "", nil},
|
||||||
|
{"valid RFC3339", validDate, &expectedTime},
|
||||||
|
{"invalid format", "2024-01-15", nil},
|
||||||
|
{"invalid date", "not-a-date", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseOptionalRFC3339(tt.input)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, *tt.expected, *result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// strPtr returns a pointer to a copy of s (test helper for optional fields).
func strPtr(s string) *string {
	v := s
	return &v
}
|
||||||
|
|
||||||
|
// intPtr returns a pointer to a copy of i (test helper for optional fields).
func intPtr(i int) *int {
	v := i
	return &v
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager persists and queries reverse-proxy access logs.
type Manager interface {
	// SaveAccessLog stores a single access log entry.
	SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
	// GetAllAccessLogs returns one page of logs for the account together with
	// the total record count. The userID presumably drives permission
	// checks in the implementation — TODO confirm against the impl.
	GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handler serves the reverse-proxy access-log HTTP endpoints.
type handler struct {
	manager accesslogs.Manager // access-log persistence and query layer
}
|
||||||
|
|
||||||
|
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
|
||||||
|
h := &handler{
|
||||||
|
manager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var filter accesslogs.AccessLogFilter
|
||||||
|
filter.ParseFromRequest(r)
|
||||||
|
|
||||||
|
logs, totalCount, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, &filter)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiLogs := make([]api.ProxyAccessLog, 0, len(logs))
|
||||||
|
for _, log := range logs {
|
||||||
|
apiLogs = append(apiLogs, *log.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &api.ProxyAccessLogsResponse{
|
||||||
|
Data: apiLogs,
|
||||||
|
Page: filter.Page,
|
||||||
|
PageSize: filter.PageSize,
|
||||||
|
TotalRecords: int(totalCount),
|
||||||
|
TotalPages: getTotalPageCount(int(totalCount), filter.PageSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTotalPageCount calculates the total number of pages needed to hold
// totalCount records at pageSize records per page. A non-positive page
// size yields zero pages.
func getTotalPageCount(totalCount, pageSize int) int {
	if pageSize <= 0 {
		return 0
	}
	pages := totalCount / pageSize
	if totalCount%pageSize != 0 {
		pages++
	}
	return pages
}
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// managerImpl is the default accesslogs.Manager backed by the management store.
type managerImpl struct {
	store              store.Store
	permissionsManager permissions.Manager
	geo                geolocation.Geolocation // optional; checked for nil before use
}

// NewManager returns an accesslogs.Manager that persists and queries access
// logs via the given store, enforcing permissions and (when geo is non-nil)
// enriching entries with geolocation data.
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
	return &managerImpl{
		store:              store,
		permissionsManager: permissionsManager,
		geo:                geo,
	}
}
|
||||||
|
|
||||||
|
// SaveAccessLog saves an access log entry to the database after enriching it
// with geolocation data (country code, city name, geoname ID) looked up from
// the entry's connection IP, when a geolocation service is configured.
// Geo lookup failures are logged and ignored (the entry is stored anyway);
// store failures are logged with the entry's key fields and returned.
func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
	// Enrichment is best-effort and only attempted when both the geo service
	// and a connection IP are present.
	if m.geo != nil && logEntry.GeoLocation.ConnectionIP != nil {
		location, err := m.geo.Lookup(logEntry.GeoLocation.ConnectionIP)
		if err != nil {
			log.WithContext(ctx).Warnf("failed to get location for access log source IP [%s]: %v", logEntry.GeoLocation.ConnectionIP.String(), err)
		} else {
			logEntry.GeoLocation.CountryCode = location.Country.ISOCode
			logEntry.GeoLocation.CityName = location.City.Names.En
			logEntry.GeoLocation.GeoNameID = location.City.GeonameID
		}
	}

	if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
		log.WithContext(ctx).WithFields(log.Fields{
			"service_id": logEntry.ServiceID,
			"method":     logEntry.Method,
			"host":       logEntry.Host,
			"path":       logEntry.Path,
			"status":     logEntry.StatusCode,
		}).Errorf("failed to save access log: %v", err)
		return err
	}

	return nil
}
|
||||||
|
|
||||||
|
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering,
// after verifying the user has read permission on the Services module.
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
	if err != nil {
		return nil, 0, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, 0, status.NewPermissionDeniedError()
	}

	// Best-effort: translate email/name filters into a user ID filter.
	// NOTE(review): on failure the query below runs without those filters,
	// silently returning unfiltered results — confirm this is intended.
	if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
		log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
	}

	logs, totalCount, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID, *filter)
	if err != nil {
		return nil, 0, err
	}

	return logs, totalCount, nil
}
|
||||||
|
|
||||||
|
// resolveUserFilters converts user email/name filters to user ID filter
// by scanning the account's users for case-insensitive substring matches.
// It is a no-op when neither UserEmail nor UserName is set.
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
	if filter.UserEmail == nil && filter.UserName == nil {
		return nil
	}

	users, err := m.store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		return err
	}

	var matchingUserIDs []string
	for _, user := range users {
		// Email match takes precedence; a user matched by email is not
		// re-checked against the name filter.
		if filter.UserEmail != nil && strings.Contains(strings.ToLower(user.Email), strings.ToLower(*filter.UserEmail)) {
			matchingUserIDs = append(matchingUserIDs, user.Id)
			continue
		}
		if filter.UserName != nil && strings.Contains(strings.ToLower(user.Name), strings.ToLower(*filter.UserName)) {
			matchingUserIDs = append(matchingUserIDs, user.Id)
		}
	}

	// NOTE(review): only the first match is used because the filter supports a
	// single UserID — logs of any additional matching users are dropped. And
	// when nothing matches, the filter is left unset, so the query returns
	// logs for ALL users instead of none. Confirm both behaviors are intended.
	if len(matchingUserIDs) > 0 {
		filter.UserID = &matchingUserIDs[0]
	}

	return nil
}
|
||||||
17
management/internals/modules/reverseproxy/domain/domain.go
Normal file
17
management/internals/modules/reverseproxy/domain/domain.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package domain

// Type distinguishes how a reverse proxy domain was provisioned.
type Type string

const (
	// TypeFree is a domain derived from a connected proxy cluster address.
	TypeFree Type = "free"
	// TypeCustom is a user-registered domain validated via a CNAME record.
	TypeCustom Type = "custom"
)

// Domain is a reverse proxy domain record persisted for an account.
type Domain struct {
	// NOTE(review): autoIncrement on a string primary key looks suspicious —
	// confirm the intended ID generation strategy.
	ID            string `gorm:"unique;primaryKey;autoIncrement"`
	Domain        string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
	AccountID     string `gorm:"index"`
	TargetCluster string // The proxy cluster this domain should be validated against
	Type          Type   `gorm:"-"` // Derived at read time (free vs custom); not persisted.
	Validated     bool
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package domain

import (
	"context"
)

// Manager defines the reverse proxy custom domain lifecycle operations.
type Manager interface {
	// GetDomains lists free (cluster-derived) and custom domains for the account.
	GetDomains(ctx context.Context, accountID, userID string) ([]*Domain, error)
	// CreateDomain registers a custom domain targeting a specific proxy cluster.
	CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
	// DeleteDomain removes a custom domain from the account.
	DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
	// ValidateDomain performs CNAME validation of the domain against its target
	// cluster; it returns nothing and reports outcomes via logs, so it can run
	// in a background goroutine.
	ValidateDomain(ctx context.Context, accountID, userID, domainID string)
}
|
||||||
136
management/internals/modules/reverseproxy/domain/manager/api.go
Normal file
136
management/internals/modules/reverseproxy/domain/manager/api.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
manager Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterEndpoints(router *mux.Router, manager Manager) {
|
||||||
|
h := &handler{
|
||||||
|
manager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
|
||||||
|
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
|
||||||
|
switch t {
|
||||||
|
case domain.TypeCustom:
|
||||||
|
return api.ReverseProxyDomainTypeCustom
|
||||||
|
case domain.TypeFree:
|
||||||
|
return api.ReverseProxyDomainTypeFree
|
||||||
|
}
|
||||||
|
// By default return as a "free" domain as that is more restrictive.
|
||||||
|
// TODO: is this correct?
|
||||||
|
return api.ReverseProxyDomainTypeFree
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||||
|
resp := api.ReverseProxyDomain{
|
||||||
|
Domain: d.Domain,
|
||||||
|
Id: d.ID,
|
||||||
|
Type: domainTypeToApi(d.Type),
|
||||||
|
Validated: d.Validated,
|
||||||
|
}
|
||||||
|
if d.TargetCluster != "" {
|
||||||
|
resp.TargetCluster = &d.TargetCluster
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := make([]api.ReverseProxyDomain, 0)
|
||||||
|
for _, d := range domains {
|
||||||
|
ret = append(ret, domainToApi(d))
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.PostApiReverseProxiesDomainsJSONRequestBody
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, req.Domain, req.TargetCluster)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domainID := mux.Vars(r)["domainId"]
|
||||||
|
if domainID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domainID := mux.Vars(r)["domainId"]
|
||||||
|
if domainID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go h.manager.ValidateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
}
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// store is the subset of persistence operations the domain Manager needs,
// defined at the consumer to keep the dependency surface small.
type store interface {
	GetAccount(ctx context.Context, accountID string) (*types.Account, error)

	GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
	ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
	ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
	CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
	UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
	DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}

// proxyURLProvider reports the URLs of currently connected proxy clusters.
type proxyURLProvider interface {
	GetConnectedProxyURLs() []string
}

// Manager implements the reverse proxy custom domain lifecycle: listing,
// creation, deletion and CNAME validation against proxy clusters.
type Manager struct {
	store              store
	validator          domain.Validator
	proxyURLProvider   proxyURLProvider // may be nil; then no clusters/free domains are available
	permissionsManager permissions.Manager
}

// NewManager constructs a Manager that uses the system DNS resolver for
// CNAME-based domain validation.
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
	return Manager{
		store:            store,
		proxyURLProvider: proxyURLProvider,
		validator: domain.Validator{
			Resolver: net.DefaultResolver,
		},
		permissionsManager: permissionsManager,
	}
}
|
||||||
|
|
||||||
|
// GetDomains returns all domains usable by the account: connected proxy
// cluster addresses surfaced as "free" domains, followed by the account's
// registered custom domains. Requires read permission on the Services module.
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
	if err != nil {
		return nil, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, status.NewPermissionDeniedError()
	}

	domains, err := m.store.ListCustomDomains(ctx, accountID)
	if err != nil {
		return nil, fmt.Errorf("list custom domains: %w", err)
	}

	var ret []*domain.Domain

	// Add connected proxy clusters as free domains.
	// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
	allowList := m.proxyURLAllowList()
	log.WithFields(log.Fields{
		"accountID":      accountID,
		"proxyAllowList": allowList,
	}).Debug("getting domains with proxy allow list")

	for _, cluster := range allowList {
		// Free domains carry no ID or target cluster and are always validated.
		ret = append(ret, &domain.Domain{
			Domain:    cluster,
			AccountID: accountID,
			Type:      domain.TypeFree,
			Validated: true,
		})
	}

	// Add custom domains.
	for _, d := range domains {
		ret = append(ret, &domain.Domain{
			ID:            d.ID,
			Domain:        d.Domain,
			AccountID:     accountID,
			TargetCluster: d.TargetCluster,
			Type:          domain.TypeCustom,
			Validated:     d.Validated,
		})
	}

	return ret, nil
}
|
||||||
|
|
||||||
|
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the target cluster is in the available clusters
|
||||||
|
allowList := m.proxyURLAllowList()
|
||||||
|
clusterValid := false
|
||||||
|
for _, cluster := range allowList {
|
||||||
|
if cluster == targetCluster {
|
||||||
|
clusterValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !clusterValid {
|
||||||
|
return nil, fmt.Errorf("target cluster %s is not available", targetCluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt an initial validation against the specified cluster only
|
||||||
|
var validated bool
|
||||||
|
if m.validator.IsValid(ctx, domainName, []string{targetCluster}) {
|
||||||
|
validated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, targetCluster, validated)
|
||||||
|
if err != nil {
|
||||||
|
return d, fmt.Errorf("create domain in store: %w", err)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
|
||||||
|
// TODO: check for "no records" type error. Because that is a success condition.
|
||||||
|
return fmt.Errorf("delete domain from store: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("validate domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("validate domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).Info("starting domain validation")
|
||||||
|
|
||||||
|
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("get custom domain from store")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate only against the domain's target cluster
|
||||||
|
targetCluster := d.TargetCluster
|
||||||
|
if targetCluster == "" {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).Warn("domain has no target cluster set, skipping validation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
"targetCluster": targetCluster,
|
||||||
|
}).Info("validating domain against target cluster")
|
||||||
|
|
||||||
|
if m.validator.IsValid(context.Background(), d.Domain, []string{targetCluster}) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).Info("domain validated successfully")
|
||||||
|
d.Validated = true
|
||||||
|
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).WithError(err).Error("update custom domain in store")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
"targetCluster": targetCluster,
|
||||||
|
}).Warn("domain validation failed - CNAME does not match target cluster")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyURLAllowList retrieves a list of currently connected proxies and
|
||||||
|
// their URLs
|
||||||
|
func (m Manager) proxyURLAllowList() []string {
|
||||||
|
var reverseProxyAddresses []string
|
||||||
|
if m.proxyURLProvider != nil {
|
||||||
|
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
|
||||||
|
}
|
||||||
|
return reverseProxyAddresses
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
||||||
|
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||||
|
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||||
|
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||||
|
allowList := m.proxyURLAllowList()
|
||||||
|
if len(allowList) == 0 {
|
||||||
|
return "", fmt.Errorf("no proxy clusters available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok {
|
||||||
|
return cluster, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
customDomains, err := m.store.ListCustomDomains(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("list custom domains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetCluster, valid := extractClusterFromCustomDomains(domain, customDomains)
|
||||||
|
if valid {
|
||||||
|
return targetCluster, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
|
||||||
|
for _, customDomain := range customDomains {
|
||||||
|
if strings.HasSuffix(domain, "."+customDomain.Domain) {
|
||||||
|
return customDomain.TargetCluster, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
// Free domains have the format: <name>.<nonce>.<cluster> (e.g.,
// myapp.abc123.eu.proxy.netbird.io). It matches the domain's suffix against
// the available clusters and returns the matching cluster, plus whether any
// cluster matched. A domain exactly equal to a cluster does not match.
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
	for _, cluster := range availableClusters {
		suffix := "." + cluster
		if strings.HasSuffix(domain, suffix) {
			return cluster, true
		}
	}
	return "", false
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resolver abstracts the single DNS operation the Validator needs, allowing
// tests to substitute a fake for net.DefaultResolver.
type resolver interface {
	LookupCNAME(context.Context, string) (string, error)
}

// Validator checks custom domain setup by resolving CNAME records.
type Validator struct {
	// Resolver performs CNAME lookups; nil falls back to net.DefaultResolver.
	Resolver resolver
}

// NewValidator initializes a validator with a specific DNS Resolver.
// If a Validator is used without specifying a Resolver, then it will
// use the net.DefaultResolver.
func NewValidator(resolver resolver) *Validator {
	return &Validator{
		Resolver: resolver,
	}
}
|
||||||
|
|
||||||
|
// IsValid looks up the CNAME record for the passed domain with a prefix
|
||||||
|
// and compares it against the acceptable domains.
|
||||||
|
// If the returned CNAME matches any accepted domain, it will return true,
|
||||||
|
// otherwise, including in the event of a DNS error, it will return false.
|
||||||
|
// The comparison is very simple, so wildcards will not match if included
|
||||||
|
// in the acceptable domain list.
|
||||||
|
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
|
||||||
|
_, valid := v.ValidateWithCluster(ctx, domain, accept)
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateWithCluster validates a custom domain and returns the matched cluster address.
|
||||||
|
// Returns the cluster address and true if valid, or empty string and false if invalid.
|
||||||
|
func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) {
|
||||||
|
if v.Resolver == nil {
|
||||||
|
v.Resolver = net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
lookupDomain := "validation." + domain
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"lookupDomain": lookupDomain,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Debug("looking up CNAME for domain validation")
|
||||||
|
|
||||||
|
cname, err := v.Resolver.LookupCNAME(ctx, lookupDomain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"lookupDomain": lookupDomain,
|
||||||
|
}).WithError(err).Warn("CNAME lookup failed for domain validation")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
nakedCNAME := strings.TrimSuffix(cname, ".")
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": cname,
|
||||||
|
"nakedCNAME": nakedCNAME,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Debug("CNAME lookup result for domain validation")
|
||||||
|
|
||||||
|
for _, acceptDomain := range accept {
|
||||||
|
normalizedAccept := strings.TrimSuffix(acceptDomain, ".")
|
||||||
|
if nakedCNAME == normalizedAccept {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": nakedCNAME,
|
||||||
|
"cluster": acceptDomain,
|
||||||
|
}).Info("domain CNAME matched cluster")
|
||||||
|
return acceptDomain, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": nakedCNAME,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Warn("domain CNAME does not match any accepted cluster")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package domain_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resolver is a test double that returns a fixed CNAME for any lookup.
type resolver struct {
	CNAME string
}

// LookupCNAME implements the validator's resolver interface, always
// succeeding with the configured CNAME.
func (r resolver) LookupCNAME(_ context.Context, _ string) (string, error) {
	return r.CNAME, nil
}

// TestIsValid verifies CNAME matching, non-matching, and trailing-dot
// normalization on both the resolved CNAME and the accept list.
func TestIsValid(t *testing.T) {
	tests := map[string]struct {
		resolver interface {
			LookupCNAME(context.Context, string) (string, error)
		}
		domain string
		accept []string
		expect bool
	}{
		"match": {
			resolver: resolver{"bar.example.com."}, // Including trailing "." in response.
			domain:   "foo.example.com",
			accept:   []string{"bar.example.com"},
			expect:   true,
		},
		"no match": {
			resolver: resolver{"invalid"},
			domain:   "foo.example.com",
			accept:   []string{"bar.example.com"},
			expect:   false,
		},
		"accept trailing dot": {
			resolver: resolver{"bar.example.com."},
			domain:   "foo.example.com",
			accept:   []string{"bar.example.com."}, // Including trailing "." in accept.
			expect:   true,
		},
	}

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			validator := domain.NewValidator(test.resolver)
			actual := validator.IsValid(t.Context(), test.domain, test.accept)
			if test.expect != actual {
				t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual)
			}
		})
	}
}
|
||||||
23
management/internals/modules/reverseproxy/interface.go
Normal file
23
management/internals/modules/reverseproxy/interface.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager is the reverse proxy service lifecycle and query API.
type Manager interface {
	// CRUD operations taking both accountID and userID (presumably for
	// per-user permission checks — confirm against implementations).
	GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
	GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
	CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
	UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
	DeleteService(ctx context.Context, accountID, userID, serviceID string) error
	// State updates and reloads keyed by account/service only (no userID).
	SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
	SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
	ReloadAllServicesForAccount(ctx context.Context, accountID string) error
	ReloadService(ctx context.Context, accountID, serviceID string) error
	// Lookups keyed by account (or global) only, with no userID parameter.
	GetGlobalServices(ctx context.Context) ([]*Service, error)
	GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
	GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
	GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
}
|
||||||
225
management/internals/modules/reverseproxy/interface_mock.go
Normal file
225
management/internals/modules/reverseproxy/interface_mock.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go

// Package reverseproxy is a generated GoMock package.
package reverseproxy

import (
	context "context"
	reflect "reflect"

	gomock "github.com/golang/mock/gomock"
)

// MockManager is a mock of Manager interface.
type MockManager struct {
	ctrl     *gomock.Controller
	recorder *MockManagerMockRecorder
}

// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
	mock *MockManager
}

// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
	mock := &MockManager{ctrl: ctrl}
	mock.recorder = &MockManagerMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
	return m.recorder
}

// CreateService mocks base method.
func (m *MockManager) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CreateService", ctx, accountID, userID, service)
	ret0, _ := ret[0].(*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// CreateService indicates an expected call of CreateService.
func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}

// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "DeleteService", ctx, accountID, userID, serviceID)
	ret0, _ := ret[0].(error)
	return ret0
}

// DeleteService indicates an expected call of DeleteService.
func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID)
}

// GetAccountServices mocks base method.
func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetAccountServices", ctx, accountID)
	ret0, _ := ret[0].([]*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetAccountServices indicates an expected call of GetAccountServices.
func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}

// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetAllServices", ctx, accountID, userID)
	ret0, _ := ret[0].([]*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetAllServices indicates an expected call of GetAllServices.
func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}

// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetGlobalServices", ctx)
	ret0, _ := ret[0].([]*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetGlobalServices indicates an expected call of GetGlobalServices.
func (mr *MockManagerMockRecorder) GetGlobalServices(ctx interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalServices", reflect.TypeOf((*MockManager)(nil).GetGlobalServices), ctx)
}

// GetService mocks base method.
func (m *MockManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetService", ctx, accountID, userID, serviceID)
	ret0, _ := ret[0].(*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetService indicates an expected call of GetService.
func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
}

// GetServiceByID mocks base method.
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetServiceByID", ctx, accountID, serviceID)
	ret0, _ := ret[0].(*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetServiceByID indicates an expected call of GetServiceByID.
func (mr *MockManagerMockRecorder) GetServiceByID(ctx, accountID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByID", reflect.TypeOf((*MockManager)(nil).GetServiceByID), ctx, accountID, serviceID)
}

// GetServiceIDByTargetID mocks base method.
func (m *MockManager) GetServiceIDByTargetID(ctx context.Context, accountID, resourceID string) (string, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetServiceIDByTargetID", ctx, accountID, resourceID)
	ret0, _ := ret[0].(string)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetServiceIDByTargetID indicates an expected call of GetServiceIDByTargetID.
func (mr *MockManagerMockRecorder) GetServiceIDByTargetID(ctx, accountID, resourceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceIDByTargetID", reflect.TypeOf((*MockManager)(nil).GetServiceIDByTargetID), ctx, accountID, resourceID)
}

// ReloadAllServicesForAccount mocks base method.
func (m *MockManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "ReloadAllServicesForAccount", ctx, accountID)
	ret0, _ := ret[0].(error)
	return ret0
}

// ReloadAllServicesForAccount indicates an expected call of ReloadAllServicesForAccount.
func (mr *MockManagerMockRecorder) ReloadAllServicesForAccount(ctx, accountID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadAllServicesForAccount", reflect.TypeOf((*MockManager)(nil).ReloadAllServicesForAccount), ctx, accountID)
}

// ReloadService mocks base method.
func (m *MockManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "ReloadService", ctx, accountID, serviceID)
	ret0, _ := ret[0].(error)
	return ret0
}

// ReloadService indicates an expected call of ReloadService.
func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
}

// SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "SetCertificateIssuedAt", ctx, accountID, serviceID)
	ret0, _ := ret[0].(error)
	return ret0
}

// SetCertificateIssuedAt indicates an expected call of SetCertificateIssuedAt.
func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCertificateIssuedAt", reflect.TypeOf((*MockManager)(nil).SetCertificateIssuedAt), ctx, accountID, serviceID)
}

// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
	ret0, _ := ret[0].(error)
	return ret0
}

// SetStatus indicates an expected call of SetStatus.
func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
}

// UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "UpdateService", ctx, accountID, userID, service)
	ret0, _ := ret[0].(*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// UpdateService indicates an expected call of UpdateService.
func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
}
|
||||||
170
management/internals/modules/reverseproxy/manager/api.go
Normal file
170
management/internals/modules/reverseproxy/manager/api.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
|
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handler serves the reverse proxy service HTTP API endpoints.
type handler struct {
	// manager implements the service CRUD operations invoked by each endpoint.
	manager reverseproxy.Manager
}

// RegisterEndpoints registers all service HTTP endpoints.
// It also wires the domain sub-handler under /reverse-proxies and the
// access-logs handler on the parent router before attaching the
// /reverse-proxies/services routes.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
	h := &handler{
		manager: manager,
	}

	domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
	domainmanager.RegisterEndpoints(domainRouter, domainManager)

	accesslogsmanager.RegisterEndpoints(router, accessLogsManager)

	router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
	router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
	router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
	router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
	router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
}
|
||||||
|
|
||||||
|
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiServices := make([]*api.Service, 0, len(allServices))
|
||||||
|
for _, service := range allServices {
|
||||||
|
apiServices = append(apiServices, service.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, apiServices)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.ServiceRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service := new(reverseproxy.Service)
|
||||||
|
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||||
|
|
||||||
|
if err = service.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
createdService, err := h.manager.CreateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceID := mux.Vars(r)["serviceId"]
|
||||||
|
if serviceID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := h.manager.GetService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceID := mux.Vars(r)["serviceId"]
|
||||||
|
if serviceID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.ServiceRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service := new(reverseproxy.Service)
|
||||||
|
service.ID = serviceID
|
||||||
|
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||||
|
|
||||||
|
if err = service.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedService, err := h.manager.UpdateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceID := mux.Vars(r)["serviceId"]
|
||||||
|
if serviceID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.manager.DeleteService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
541
management/internals/modules/reverseproxy/manager/manager.go
Normal file
541
management/internals/modules/reverseproxy/manager/manager.go
Normal file
@@ -0,0 +1,541 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// unknownHostPlaceholder is written into a target's Host field when the
// referenced peer or resource can no longer be resolved from the store.
const unknownHostPlaceholder = "unknown"

// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
	DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
}

// managerImpl is the reverseproxy.Manager implementation backed by the
// account store, the permissions manager, and the proxy gRPC server.
type managerImpl struct {
	store              store.Store
	accountManager     account.Manager
	permissionsManager permissions.Manager
	proxyGRPCServer    *nbgrpc.ProxyServiceServer
	clusterDeriver     ClusterDeriver // optional; may be nil, in which case cluster derivation is skipped
}

// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
	return &managerImpl{
		store:              store,
		accountManager:     accountManager,
		permissionsManager: permissionsManager,
		proxyGRPCServer:    proxyGRPCServer,
		clusterDeriver:     clusterDeriver,
	}
}
|
||||||
|
|
||||||
|
// GetAllServices returns every reverse proxy service of the account after
// verifying the caller's read permission on the Services module. Target
// hosts are resolved to their current peer/resource addresses before
// returning.
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
	if err != nil {
		return nil, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, status.NewPermissionDeniedError()
	}

	services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		return nil, fmt.Errorf("failed to get services: %w", err)
	}

	// Resolve each target's host in place for the response.
	for _, service := range services {
		err = m.replaceHostByLookup(ctx, accountID, service)
		if err != nil {
			return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
		}
	}

	return services, nil
}
|
||||||
|
|
||||||
|
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||||
|
for _, target := range service.Targets {
|
||||||
|
switch target.TargetType {
|
||||||
|
case reverseproxy.TargetTypePeer:
|
||||||
|
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||||
|
target.Host = unknownHostPlaceholder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target.Host = peer.IP.String()
|
||||||
|
case reverseproxy.TargetTypeHost:
|
||||||
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||||
|
target.Host = unknownHostPlaceholder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target.Host = resource.Prefix.Addr().String()
|
||||||
|
case reverseproxy.TargetTypeDomain:
|
||||||
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||||
|
target.Host = unknownHostPlaceholder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target.Host = resource.Domain
|
||||||
|
case reverseproxy.TargetTypeSubnet:
|
||||||
|
// For subnets we do not do any lookups on the resource
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateService creates a new reverse proxy service for the account.
//
// The order of steps matters: permission check, record initialization
// (cluster derivation, secret hashing, session key generation),
// transactional persistence, activity event, target-host resolution for
// the response, then the create notification to the service's proxy
// cluster and a peer update for the account.
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
	if err != nil {
		return nil, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, status.NewPermissionDeniedError()
	}

	if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
		return nil, err
	}

	if err := m.persistNewService(ctx, accountID, service); err != nil {
		return nil, err
	}

	m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())

	// Resolve target hosts in place for the returned service (done after
	// persistence).
	err = m.replaceHostByLookup(ctx, accountID, service)
	if err != nil {
		return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
	}

	m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return service, nil
}
|
||||||
|
|
||||||
|
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||||
|
if m.clusterDeriver != nil {
|
||||||
|
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||||
|
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||||
|
}
|
||||||
|
service.ProxyCluster = proxyCluster
|
||||||
|
}
|
||||||
|
|
||||||
|
service.AccountID = accountID
|
||||||
|
service.InitNewRecord()
|
||||||
|
|
||||||
|
if err := service.Auth.HashSecrets(); err != nil {
|
||||||
|
return fmt.Errorf("hash secrets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keyPair, err := sessionkey.GenerateKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate session keys: %w", err)
|
||||||
|
}
|
||||||
|
service.SessionPrivateKey = keyPair.PrivateKey
|
||||||
|
service.SessionPublicKey = keyPair.PublicKey
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := transaction.CreateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to create service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||||
|
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||||
|
if err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||||
|
return fmt.Errorf("failed to check existing service: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if existingService != nil && existingService.ID != excludeServiceID {
|
||||||
|
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateService replaces an existing service after verifying the caller's
// update permission. Newly supplied auth secrets are hashed first; secrets
// and record metadata omitted from the request are carried over from the
// stored record inside persistServiceUpdate. After persistence it emits an
// activity event, resolves target hosts for the response, notifies the
// affected proxy cluster(s), and triggers a peer update for the account.
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
	if err != nil {
		return nil, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, status.NewPermissionDeniedError()
	}

	if err := service.Auth.HashSecrets(); err != nil {
		return nil, fmt.Errorf("hash secrets: %w", err)
	}

	updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
	if err != nil {
		return nil, err
	}

	m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())

	if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
		return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
	}

	m.sendServiceUpdateNotifications(service, updateInfo)
	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return service, nil
}
|
||||||
|
|
||||||
|
// serviceUpdateInfo captures the differences between the stored service
// and the incoming update that the notification logic needs afterwards.
type serviceUpdateInfo struct {
	oldCluster            string // cluster the service belonged to before the update
	domainChanged         bool   // the update changes the service domain
	serviceEnabledChanged bool   // the Enabled flag flipped
}

// persistServiceUpdate applies the update inside a single transaction. It
// loads the current record with an update lock, handles domain/cluster
// changes, carries over secrets and metadata the request omitted,
// validates target references, and writes the result. The returned
// serviceUpdateInfo describes what changed, for notification purposes.
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
	var updateInfo serviceUpdateInfo

	err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
		if err != nil {
			return err
		}

		updateInfo.oldCluster = existingService.ProxyCluster
		updateInfo.domainChanged = existingService.Domain != service.Domain

		if updateInfo.domainChanged {
			if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
				return err
			}
		} else {
			// Domain unchanged: keep the service on its current cluster.
			service.ProxyCluster = existingService.ProxyCluster
		}

		m.preserveExistingAuthSecrets(service, existingService)
		m.preserveServiceMetadata(service, existingService)
		updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled

		if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
			return err
		}

		if err := transaction.UpdateService(ctx, service); err != nil {
			return fmt.Errorf("update service: %w", err)
		}

		return nil
	})

	return &updateInfo, err
}
|
||||||
|
|
||||||
|
// handleDomainChange runs when an update changes the service domain: it
// verifies the new domain is not taken by another service and, when a
// cluster deriver is configured, re-derives the owning proxy cluster.
//
// NOTE(review): unlike creation (which fails), a failed derivation here is
// only logged and service.ProxyCluster keeps its incoming value — confirm
// this asymmetry is intentional.
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
	if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
		return err
	}

	if m.clusterDeriver != nil {
		newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
		if err != nil {
			log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
		} else {
			service.ProxyCluster = newCluster
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
|
||||||
|
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||||
|
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||||
|
service.Auth.PasswordAuth.Password == "" {
|
||||||
|
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||||
|
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||||
|
service.Auth.PinAuth.Pin == "" {
|
||||||
|
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preserveServiceMetadata carries server-managed fields from the stored service
// over to the incoming update, so API callers cannot overwrite metadata or the
// session key pair.
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
	service.Meta = existingService.Meta
	service.SessionPrivateKey = existingService.SessionPrivateKey
	service.SessionPublicKey = existingService.SessionPublicKey
}
|
||||||
|
|
||||||
|
// sendServiceUpdateNotifications pushes the updated service state to the proxy
// cluster(s) over gRPC. The notification shape depends on what changed:
//   - domain moved to a different cluster: delete on the old cluster, create on the new one
//   - service was just disabled: delete on its cluster
//   - service was just enabled: create on its cluster
//   - anything else: plain update on its cluster
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
	oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()

	switch {
	case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
	case !service.Enabled && updateInfo.serviceEnabledChanged:
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
	case service.Enabled && updateInfo.serviceEnabledChanged:
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
	default:
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
	}
}
|
||||||
|
|
||||||
|
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
// Peer targets must resolve to a peer; host/subnet/domain targets must resolve
// to a network resource. A NotFound lookup is reported as InvalidArgument so
// API callers receive a client error; other store errors are wrapped as-is.
// NOTE(review): unknown target types pass through unvalidated — confirm this is intentional.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
	for _, target := range targets {
		switch target.TargetType {
		case reverseproxy.TargetTypePeer:
			if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
				if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
					return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
				}
				return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
			}
		case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
			if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
				if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
					return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
				}
				return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
			}
		}
	}
	return nil
}
|
||||||
|
|
||||||
|
// DeleteService removes a service after validating the caller's delete
// permission. The read and delete happen inside one transaction (row locked
// for update); on success it emits an activity event, tells the service's
// proxy cluster to drop the mapping, and triggers an account peer update.
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
	if err != nil {
		return status.NewPermissionValidationError(err)
	}
	if !ok {
		return status.NewPermissionDeniedError()
	}

	var service *reverseproxy.Service
	err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		var err error
		// Fetch first so the event metadata and cluster notification below can
		// use the service's state as it was before deletion.
		service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
		if err != nil {
			return err
		}

		if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
			return fmt.Errorf("failed to delete service: %w", err)
		}

		return nil
	})
	if err != nil {
		return err
	}

	m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())

	m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return nil
}
|
||||||
|
|
||||||
|
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||||
|
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||||
|
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service.Meta.CertificateIssuedAt = time.Now()
|
||||||
|
|
||||||
|
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||||
|
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service.Meta.Status = string(status)
|
||||||
|
|
||||||
|
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to update service status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadService re-resolves the target hosts of a single service and pushes
// the refreshed mapping to its proxy cluster, then triggers a peer update for
// the account.
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
	service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
	if err != nil {
		return fmt.Errorf("failed to get service: %w", err)
	}

	err = m.replaceHostByLookup(ctx, accountID, service)
	if err != nil {
		return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
	}

	m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return nil
}
|
||||||
|
|
||||||
|
// ReloadAllServicesForAccount re-resolves target hosts for every service in
// the account and pushes each refreshed mapping to its proxy cluster.
// NOTE(review): the first lookup failure aborts the loop, leaving remaining
// services un-reloaded — confirm this all-or-nothing behavior is intended.
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
	services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		return fmt.Errorf("failed to get services: %w", err)
	}

	for _, service := range services {
		err = m.replaceHostByLookup(ctx, accountID, service)
		if err != nil {
			return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
		}
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
	}

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return nil
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||||
|
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||||
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServiceIDByTargetID resolves which service (if any) exposes the given
// peer/resource ID as a target. It returns an empty string with a nil error
// when no service references the target.
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
	target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
	if err != nil {
		// "Not found" is a normal answer here, not an error.
		if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
			return "", nil
		}
		return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
	}

	// Defensive: treat a nil target with a nil error the same as "not found".
	if target == nil {
		return "", nil
	}

	return target.ServiceID, nil
}
|
||||||
@@ -0,0 +1,375 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestInitializeServiceForCreate covers initializeServiceForCreate without a
// cluster deriver: account assignment, ID generation, and per-service session
// key creation (including uniqueness across services).
func TestInitializeServiceForCreate(t *testing.T) {
	ctx := context.Background()
	accountID := "test-account"

	t.Run("successful initialization without cluster deriver", func(t *testing.T) {
		mgr := &managerImpl{
			clusterDeriver: nil,
		}

		service := &reverseproxy.Service{
			Domain: "example.com",
			Auth:   reverseproxy.AuthConfig{},
		}

		err := mgr.initializeServiceForCreate(ctx, accountID, service)

		assert.NoError(t, err)
		assert.Equal(t, accountID, service.AccountID)
		assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver")
		assert.NotEmpty(t, service.ID, "service ID should be initialized")
		assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated")
		assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated")
	})

	t.Run("verifies session keys are different", func(t *testing.T) {
		mgr := &managerImpl{
			clusterDeriver: nil,
		}

		service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
		service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}

		err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
		err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)

		assert.NoError(t, err1)
		assert.NoError(t, err2)
		// Each service must get its own freshly generated key pair.
		assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique")
		assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique")
	})
}
|
||||||
|
|
||||||
|
// TestCheckDomainAvailable exercises checkDomainAvailable against a mocked
// store: a free domain, a taken domain, self-exclusion during update (same vs.
// different service ID), and a non-NotFound store failure.
func TestCheckDomainAvailable(t *testing.T) {
	ctx := context.Background()
	accountID := "test-account"

	tests := []struct {
		name             string
		domain           string
		excludeServiceID string
		setupMock        func(*store.MockStore)
		expectedError    bool
		errorType        status.Type
	}{
		{
			name:             "domain available - not found",
			domain:           "available.com",
			excludeServiceID: "",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "available.com").
					Return(nil, status.Errorf(status.NotFound, "not found"))
			},
			expectedError: false,
		},
		{
			name:             "domain already exists",
			domain:           "exists.com",
			excludeServiceID: "",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "exists.com").
					Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
			},
			expectedError: true,
			errorType:     status.AlreadyExists,
		},
		{
			name:             "domain exists but excluded (same ID)",
			domain:           "exists.com",
			excludeServiceID: "service-123",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "exists.com").
					Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
			},
			expectedError: false,
		},
		{
			name:             "domain exists with different ID",
			domain:           "exists.com",
			excludeServiceID: "service-456",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "exists.com").
					Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
			},
			expectedError: true,
			errorType:     status.AlreadyExists,
		},
		{
			name:             "store error (non-NotFound)",
			domain:           "error.com",
			excludeServiceID: "",
			setupMock: func(ms *store.MockStore) {
				ms.EXPECT().
					GetServiceByDomain(ctx, accountID, "error.com").
					Return(nil, errors.New("database error"))
			},
			expectedError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			mockStore := store.NewMockStore(ctrl)
			tt.setupMock(mockStore)

			mgr := &managerImpl{}
			err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)

			if tt.expectedError {
				require.Error(t, err)
				// errorType zero value means "any error is acceptable".
				if tt.errorType != 0 {
					sErr, ok := status.FromError(err)
					require.True(t, ok, "error should be a status error")
					assert.Equal(t, tt.errorType, sErr.Type())
				}
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestCheckDomainAvailable_EdgeCases covers boundary inputs: an empty domain,
// an empty exclusion ID against an existing service, and a store that returns
// (nil, nil).
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
	ctx := context.Background()
	accountID := "test-account"

	t.Run("empty domain", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		mockStore := store.NewMockStore(ctrl)
		mockStore.EXPECT().
			GetServiceByDomain(ctx, accountID, "").
			Return(nil, status.Errorf(status.NotFound, "not found"))

		mgr := &managerImpl{}
		err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")

		assert.NoError(t, err)
	})

	t.Run("empty exclude ID with existing service", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		mockStore := store.NewMockStore(ctrl)
		mockStore.EXPECT().
			GetServiceByDomain(ctx, accountID, "test.com").
			Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)

		mgr := &managerImpl{}
		err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")

		assert.Error(t, err)
		sErr, ok := status.FromError(err)
		require.True(t, ok)
		assert.Equal(t, status.AlreadyExists, sErr.Type())
	})

	t.Run("nil existing service with nil error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		mockStore := store.NewMockStore(ctrl)
		mockStore.EXPECT().
			GetServiceByDomain(ctx, accountID, "nil.com").
			Return(nil, nil)

		mgr := &managerImpl{}
		err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")

		assert.NoError(t, err)
	})
}
|
||||||
|
|
||||||
|
// TestPersistNewService verifies that persistNewService runs the domain
// availability check and the create inside the store transaction, and that a
// taken domain surfaces as an AlreadyExists status error.
func TestPersistNewService(t *testing.T) {
	ctx := context.Background()
	accountID := "test-account"

	t.Run("successful service creation with no targets", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		mockStore := store.NewMockStore(ctrl)
		service := &reverseproxy.Service{
			ID:      "service-123",
			Domain:  "new.com",
			Targets: []*reverseproxy.Target{},
		}

		// Mock ExecuteInTransaction to execute the function immediately
		mockStore.EXPECT().
			ExecuteInTransaction(ctx, gomock.Any()).
			DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
				// Create another mock for the transaction
				txMock := store.NewMockStore(ctrl)
				txMock.EXPECT().
					GetServiceByDomain(ctx, accountID, "new.com").
					Return(nil, status.Errorf(status.NotFound, "not found"))
				txMock.EXPECT().
					CreateService(ctx, service).
					Return(nil)

				return fn(txMock)
			})

		mgr := &managerImpl{store: mockStore}
		err := mgr.persistNewService(ctx, accountID, service)

		assert.NoError(t, err)
	})

	t.Run("domain already exists", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		mockStore := store.NewMockStore(ctrl)
		service := &reverseproxy.Service{
			ID:      "service-123",
			Domain:  "existing.com",
			Targets: []*reverseproxy.Target{},
		}

		mockStore.EXPECT().
			ExecuteInTransaction(ctx, gomock.Any()).
			DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
				// The transaction only reaches the domain lookup: CreateService
				// must not be called when the domain is taken.
				txMock := store.NewMockStore(ctrl)
				txMock.EXPECT().
					GetServiceByDomain(ctx, accountID, "existing.com").
					Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)

				return fn(txMock)
			})

		mgr := &managerImpl{store: mockStore}
		err := mgr.persistNewService(ctx, accountID, service)

		require.Error(t, err)
		sErr, ok := status.FromError(err)
		require.True(t, ok)
		assert.Equal(t, status.AlreadyExists, sErr.Type())
	})
}
|
||||||
|
// TestPreserveExistingAuthSecrets verifies that an update with an empty
// password/PIN keeps the stored hashed secret, while a supplied secret
// replaces it.
func TestPreserveExistingAuthSecrets(t *testing.T) {
	mgr := &managerImpl{}

	t.Run("preserve password when empty", func(t *testing.T) {
		existing := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PasswordAuth: &reverseproxy.PasswordAuthConfig{
					Enabled:  true,
					Password: "hashed-password",
				},
			},
		}

		updated := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PasswordAuth: &reverseproxy.PasswordAuthConfig{
					Enabled:  true,
					Password: "",
				},
			},
		}

		mgr.preserveExistingAuthSecrets(updated, existing)

		assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
	})

	t.Run("preserve pin when empty", func(t *testing.T) {
		existing := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PinAuth: &reverseproxy.PINAuthConfig{
					Enabled: true,
					Pin:     "hashed-pin",
				},
			},
		}

		updated := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PinAuth: &reverseproxy.PINAuthConfig{
					Enabled: true,
					Pin:     "",
				},
			},
		}

		mgr.preserveExistingAuthSecrets(updated, existing)

		assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth)
	})

	t.Run("do not preserve when password is provided", func(t *testing.T) {
		existing := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PasswordAuth: &reverseproxy.PasswordAuthConfig{
					Enabled:  true,
					Password: "old-password",
				},
			},
		}

		updated := &reverseproxy.Service{
			Auth: reverseproxy.AuthConfig{
				PasswordAuth: &reverseproxy.PasswordAuthConfig{
					Enabled:  true,
					Password: "new-password",
				},
			},
		}

		mgr.preserveExistingAuthSecrets(updated, existing)

		assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password)
		assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
	})
}
|
||||||
|
|
||||||
|
// TestPreserveServiceMetadata verifies that server-managed fields (metadata
// and session keys) from the stored service are copied onto an update.
func TestPreserveServiceMetadata(t *testing.T) {
	mgr := &managerImpl{}

	existing := &reverseproxy.Service{
		Meta: reverseproxy.ServiceMeta{
			CertificateIssuedAt: time.Now(),
			Status:              "active",
		},
		SessionPrivateKey: "private-key",
		SessionPublicKey:  "public-key",
	}

	updated := &reverseproxy.Service{
		Domain: "updated.com",
	}

	mgr.preserveServiceMetadata(updated, existing)

	assert.Equal(t, existing.Meta, updated.Meta)
	assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey)
	assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey)
}
|
||||||
463
management/internals/modules/reverseproxy/reverseproxy.go
Normal file
463
management/internals/modules/reverseproxy/reverseproxy.go
Normal file
@@ -0,0 +1,463 @@
|
|||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Operation is the kind of change being propagated to proxy clusters for a
// service mapping.
type Operation string

const (
	Create Operation = "create"
	Update Operation = "update"
	Delete Operation = "delete"
)
|
||||||
|
|
||||||
|
// ProxyStatus is the lifecycle state of a proxied service as tracked in
// ServiceMeta.Status.
type ProxyStatus string

const (
	StatusPending            ProxyStatus = "pending"
	StatusActive             ProxyStatus = "active"
	StatusTunnelNotCreated   ProxyStatus = "tunnel_not_created"
	StatusCertificatePending ProxyStatus = "certificate_pending"
	StatusCertificateFailed  ProxyStatus = "certificate_failed"
	StatusError              ProxyStatus = "error"

	// Discriminator values for Target.TargetType.
	TargetTypePeer   = "peer"
	TargetTypeHost   = "host"
	TargetTypeDomain = "domain"
	TargetTypeSubnet = "subnet"
)
|
||||||
|
|
||||||
|
// Target is a single upstream endpoint of a proxied Service. Targets are
// stored as their own rows and joined to a Service by ServiceID.
type Target struct {
	ID         uint    `gorm:"primaryKey" json:"-"`
	AccountID  string  `gorm:"index:idx_target_account;not null" json:"-"`
	ServiceID  string  `gorm:"index:idx_service_targets;not null" json:"-"`
	Path       *string `json:"path,omitempty"`
	Host       string  `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
	Port       int     `gorm:"index:idx_target_port" json:"port"`
	Protocol   string  `gorm:"index:idx_target_protocol" json:"protocol"`
	TargetId   string  `gorm:"index:idx_target_id" json:"target_id"`   // ID of the referenced peer or network resource
	TargetType string  `gorm:"index:idx_target_type" json:"target_type"` // one of the TargetType* constants
	Enabled    bool    `gorm:"index:idx_target_enabled" json:"enabled"`
}
|
||||||
|
|
||||||
|
// PasswordAuthConfig enables password protection for a service. Password holds
// the plaintext on input and the argon2id hash after HashSecrets is applied.
type PasswordAuthConfig struct {
	Enabled  bool   `json:"enabled"`
	Password string `json:"password"`
}
|
||||||
|
|
||||||
|
// PINAuthConfig enables PIN protection for a service. Pin holds the plaintext
// on input and the argon2id hash after HashSecrets is applied.
type PINAuthConfig struct {
	Enabled bool   `json:"enabled"`
	Pin     string `json:"pin"`
}
|
||||||
|
|
||||||
|
// BearerAuthConfig enables bearer-token authentication, optionally scoped to a
// set of distribution groups.
type BearerAuthConfig struct {
	Enabled            bool     `json:"enabled"`
	DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
|
||||||
|
|
||||||
|
// AuthConfig bundles the optional authentication methods of a service; a nil
// pointer means the method is not configured.
type AuthConfig struct {
	PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
	PinAuth      *PINAuthConfig      `json:"pin_auth,omitempty" gorm:"serializer:json"`
	BearerAuth   *BearerAuthConfig   `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
|
||||||
|
|
||||||
|
func (a *AuthConfig) HashSecrets() error {
|
||||||
|
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
|
||||||
|
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
a.PasswordAuth.Password = hashedPassword
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
|
||||||
|
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash pin: %w", err)
|
||||||
|
}
|
||||||
|
a.PinAuth.Pin = hashedPin
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthConfig) ClearSecrets() {
|
||||||
|
if a.PasswordAuth != nil {
|
||||||
|
a.PasswordAuth.Password = ""
|
||||||
|
}
|
||||||
|
if a.PinAuth != nil {
|
||||||
|
a.PinAuth.Pin = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDCValidationConfig carries the token-validation settings passed along to
// proxy clusters (see ToProtoMapping).
type OIDCValidationConfig struct {
	Issuer             string
	Audiences          []string
	KeysLocation       string
	MaxTokenAgeSeconds int64
}
|
||||||
|
|
||||||
|
// ServiceMeta holds server-managed bookkeeping for a Service; it is embedded
// into the service row with a "meta_" column prefix.
type ServiceMeta struct {
	CreatedAt           time.Time
	CertificateIssuedAt time.Time // zero value means no certificate has been issued yet
	Status              string    // string form of a ProxyStatus value
}
|
||||||
|
|
||||||
|
// Service is a reverse-proxied service: a public domain routed through a proxy
// cluster to one or more targets inside the account's network.
type Service struct {
	ID           string    `gorm:"primaryKey"`
	AccountID    string    `gorm:"index"`
	Name         string
	Domain       string    `gorm:"index"`
	ProxyCluster string    `gorm:"index"`
	Targets      []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
	Enabled      bool
	PassHostHeader   bool
	RewriteRedirects bool
	Auth             AuthConfig  `gorm:"serializer:json"`
	Meta             ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
	// Session key pair used by the proxy; generated on creation and preserved
	// across updates.
	SessionPrivateKey string `gorm:"column:session_private_key"`
	SessionPublicKey  string `gorm:"column:session_public_key"`
}
|
||||||
|
|
||||||
|
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
||||||
|
for _, target := range targets {
|
||||||
|
target.AccountID = accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Service{
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: name,
|
||||||
|
Domain: domain,
|
||||||
|
ProxyCluster: proxyCluster,
|
||||||
|
Targets: targets,
|
||||||
|
Enabled: enabled,
|
||||||
|
}
|
||||||
|
s.InitNewRecord()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
	s.ID = xid.New().String()
	// Fresh records start in the pending state with a new creation timestamp.
	s.Meta = ServiceMeta{
		CreatedAt: time.Now(),
		Status:    string(StatusPending),
	}
}
|
||||||
|
|
||||||
|
// ToAPIResponse converts the service into its public API representation.
// NOTE(review): this clears the secret fields on the receiver itself (via
// ClearSecrets) as a side effect — callers must not rely on s.Auth still
// holding hashes afterwards; confirm this is intentional.
func (s *Service) ToAPIResponse() *api.Service {
	s.Auth.ClearSecrets()

	authConfig := api.ServiceAuthConfig{}

	if s.Auth.PasswordAuth != nil {
		authConfig.PasswordAuth = &api.PasswordAuthConfig{
			Enabled:  s.Auth.PasswordAuth.Enabled,
			Password: s.Auth.PasswordAuth.Password,
		}
	}

	if s.Auth.PinAuth != nil {
		authConfig.PinAuth = &api.PINAuthConfig{
			Enabled: s.Auth.PinAuth.Enabled,
			Pin:     s.Auth.PinAuth.Pin,
		}
	}

	if s.Auth.BearerAuth != nil {
		authConfig.BearerAuth = &api.BearerAuthConfig{
			Enabled:            s.Auth.BearerAuth.Enabled,
			DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
		}
	}

	// Convert internal targets to API targets
	apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
	for _, target := range s.Targets {
		apiTargets = append(apiTargets, api.ServiceTarget{
			Path:       target.Path,
			Host:       &target.Host,
			Port:       target.Port,
			Protocol:   api.ServiceTargetProtocol(target.Protocol),
			TargetId:   target.TargetId,
			TargetType: api.ServiceTargetTargetType(target.TargetType),
			Enabled:    target.Enabled,
		})
	}

	meta := api.ServiceMeta{
		CreatedAt: s.Meta.CreatedAt,
		Status:    api.ServiceMetaStatus(s.Meta.Status),
	}

	// A zero CertificateIssuedAt means no certificate yet; omit it from the response.
	if !s.Meta.CertificateIssuedAt.IsZero() {
		meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
	}

	resp := &api.Service{
		Id:               s.ID,
		Name:             s.Name,
		Domain:           s.Domain,
		Targets:          apiTargets,
		Enabled:          s.Enabled,
		PassHostHeader:   &s.PassHostHeader,
		RewriteRedirects: &s.RewriteRedirects,
		Auth:             authConfig,
		Meta:             meta,
	}

	if s.ProxyCluster != "" {
		resp.ProxyCluster = &s.ProxyCluster
	}

	return resp
}
|
||||||
|
|
||||||
|
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
|
||||||
|
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
||||||
|
for _, target := range s.Targets {
|
||||||
|
if !target.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Make path prefix stripping configurable per-target.
|
||||||
|
// Currently the matching prefix is baked into the target URL path,
|
||||||
|
// so the proxy strips-then-re-adds it (effectively a no-op).
|
||||||
|
targetURL := url.URL{
|
||||||
|
Scheme: target.Protocol,
|
||||||
|
Host: target.Host,
|
||||||
|
Path: "/", // TODO: support service path
|
||||||
|
}
|
||||||
|
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||||
|
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
path := "/"
|
||||||
|
if target.Path != nil {
|
||||||
|
path = *target.Path
|
||||||
|
}
|
||||||
|
pathMappings = append(pathMappings, &proto.PathMapping{
|
||||||
|
Path: path,
|
||||||
|
Target: targetURL.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
auth := &proto.Authentication{
|
||||||
|
SessionKey: s.SessionPublicKey,
|
||||||
|
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
|
||||||
|
auth.Password = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
|
||||||
|
auth.Pin = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||||
|
auth.Oidc = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.ProxyMapping{
|
||||||
|
Type: operationToProtoType(operation),
|
||||||
|
Id: s.ID,
|
||||||
|
Domain: s.Domain,
|
||||||
|
Path: pathMappings,
|
||||||
|
AuthToken: authToken,
|
||||||
|
Auth: auth,
|
||||||
|
AccountId: s.AccountID,
|
||||||
|
PassHostHeader: s.PassHostHeader,
|
||||||
|
RewriteRedirects: s.RewriteRedirects,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||||
|
switch op {
|
||||||
|
case Create:
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||||
|
case Update:
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
||||||
|
case Delete:
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||||
|
default:
|
||||||
|
log.Fatalf("unknown operation type: %v", op)
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http). Any other scheme has no default and yields false.
func isDefaultPort(scheme string, port int) bool {
	switch scheme {
	case "http":
		return port == 80
	case "https":
		return port == 443
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
||||||
|
s.Name = req.Name
|
||||||
|
s.Domain = req.Domain
|
||||||
|
s.AccountID = accountID
|
||||||
|
|
||||||
|
targets := make([]*Target, 0, len(req.Targets))
|
||||||
|
for _, apiTarget := range req.Targets {
|
||||||
|
target := &Target{
|
||||||
|
AccountID: accountID,
|
||||||
|
Path: apiTarget.Path,
|
||||||
|
Port: apiTarget.Port,
|
||||||
|
Protocol: string(apiTarget.Protocol),
|
||||||
|
TargetId: apiTarget.TargetId,
|
||||||
|
TargetType: string(apiTarget.TargetType),
|
||||||
|
Enabled: apiTarget.Enabled,
|
||||||
|
}
|
||||||
|
if apiTarget.Host != nil {
|
||||||
|
target.Host = *apiTarget.Host
|
||||||
|
}
|
||||||
|
targets = append(targets, target)
|
||||||
|
}
|
||||||
|
s.Targets = targets
|
||||||
|
|
||||||
|
s.Enabled = req.Enabled
|
||||||
|
|
||||||
|
if req.PassHostHeader != nil {
|
||||||
|
s.PassHostHeader = *req.PassHostHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.RewriteRedirects != nil {
|
||||||
|
s.RewriteRedirects = *req.RewriteRedirects
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Auth.PasswordAuth != nil {
|
||||||
|
s.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||||
|
Enabled: req.Auth.PasswordAuth.Enabled,
|
||||||
|
Password: req.Auth.PasswordAuth.Password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Auth.PinAuth != nil {
|
||||||
|
s.Auth.PinAuth = &PINAuthConfig{
|
||||||
|
Enabled: req.Auth.PinAuth.Enabled,
|
||||||
|
Pin: req.Auth.PinAuth.Pin,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Auth.BearerAuth != nil {
|
||||||
|
bearerAuth := &BearerAuthConfig{
|
||||||
|
Enabled: req.Auth.BearerAuth.Enabled,
|
||||||
|
}
|
||||||
|
if req.Auth.BearerAuth.DistributionGroups != nil {
|
||||||
|
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
|
||||||
|
}
|
||||||
|
s.Auth.BearerAuth = bearerAuth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Validate() error {
|
||||||
|
if s.Name == "" {
|
||||||
|
return errors.New("service name is required")
|
||||||
|
}
|
||||||
|
if len(s.Name) > 255 {
|
||||||
|
return errors.New("service name exceeds maximum length of 255 characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Domain == "" {
|
||||||
|
return errors.New("service domain is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.Targets) == 0 {
|
||||||
|
return errors.New("at least one target is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, target := range s.Targets {
|
||||||
|
switch target.TargetType {
|
||||||
|
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||||
|
// host field will be ignored
|
||||||
|
case TargetTypeSubnet:
|
||||||
|
if target.Host == "" {
|
||||||
|
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||||
|
}
|
||||||
|
if target.TargetId == "" {
|
||||||
|
return fmt.Errorf("target %d has empty target_id", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) EventMeta() map[string]any {
|
||||||
|
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Copy() *Service {
|
||||||
|
targets := make([]*Target, len(s.Targets))
|
||||||
|
for i, target := range s.Targets {
|
||||||
|
targetCopy := *target
|
||||||
|
targets[i] = &targetCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Service{
|
||||||
|
ID: s.ID,
|
||||||
|
AccountID: s.AccountID,
|
||||||
|
Name: s.Name,
|
||||||
|
Domain: s.Domain,
|
||||||
|
ProxyCluster: s.ProxyCluster,
|
||||||
|
Targets: targets,
|
||||||
|
Enabled: s.Enabled,
|
||||||
|
PassHostHeader: s.PassHostHeader,
|
||||||
|
RewriteRedirects: s.RewriteRedirects,
|
||||||
|
Auth: s.Auth,
|
||||||
|
Meta: s.Meta,
|
||||||
|
SessionPrivateKey: s.SessionPrivateKey,
|
||||||
|
SessionPublicKey: s.SessionPublicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
405
management/internals/modules/reverseproxy/reverseproxy_test.go
Normal file
405
management/internals/modules/reverseproxy/reverseproxy_test.go
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func validProxy() *Service {
|
||||||
|
return &Service{
|
||||||
|
Name: "test",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Valid(t *testing.T) {
|
||||||
|
require.NoError(t, validProxy().Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_EmptyName(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Name = ""
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_EmptyDomain(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Domain = ""
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "domain is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_NoTargets(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = nil
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "at least one target")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_EmptyTargetId(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].TargetId = ""
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "empty target_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_InvalidTargetType(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].TargetType = "invalid"
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_ResourceTarget(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = append(rp.Targets, &Target{
|
||||||
|
TargetId: "resource-1",
|
||||||
|
TargetType: TargetTypeHost,
|
||||||
|
Host: "example.org",
|
||||||
|
Port: 443,
|
||||||
|
Protocol: "https",
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, rp.Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = append(rp.Targets, &Target{
|
||||||
|
TargetId: "",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: "10.0.0.2",
|
||||||
|
Port: 80,
|
||||||
|
Protocol: "http",
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
err := rp.Validate()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "target 1")
|
||||||
|
assert.Contains(t, err.Error(), "empty target_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDefaultPort(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
scheme string
|
||||||
|
port int
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"http", 80, true},
|
||||||
|
{"https", 443, true},
|
||||||
|
{"http", 443, false},
|
||||||
|
{"https", 80, false},
|
||||||
|
{"http", 8080, false},
|
||||||
|
{"https", 8443, false},
|
||||||
|
{"http", 0, false},
|
||||||
|
{"https", 0, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||||
|
oidcConfig := OIDCValidationConfig{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
protocol string
|
||||||
|
host string
|
||||||
|
port int
|
||||||
|
wantTarget string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "http with default port 80 omits port",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 80,
|
||||||
|
wantTarget: "http://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https with default port 443 omits port",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 443,
|
||||||
|
wantTarget: "https://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port 0 omits port",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 0,
|
||||||
|
wantTarget: "http://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-default port is included",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 8080,
|
||||||
|
wantTarget: "http://10.0.0.1:8080/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https with non-default port is included",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 8443,
|
||||||
|
wantTarget: "https://10.0.0.1:8443/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http port 443 is included",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 443,
|
||||||
|
wantTarget: "http://10.0.0.1:443/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https port 80 is included",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 80,
|
||||||
|
wantTarget: "https://10.0.0.1:80/",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "test-id",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{
|
||||||
|
TargetId: "peer-1",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: tt.host,
|
||||||
|
Port: tt.port,
|
||||||
|
Protocol: tt.protocol,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
|
||||||
|
require.Len(t, pm.Path, 1, "should have one path mapping")
|
||||||
|
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "test-id",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
|
||||||
|
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
|
||||||
|
require.Len(t, pm.Path, 1)
|
||||||
|
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_OperationTypes(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
tests := []struct {
|
||||||
|
op Operation
|
||||||
|
want proto.ProxyMappingUpdateType
|
||||||
|
}{
|
||||||
|
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
|
||||||
|
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
|
||||||
|
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(string(tt.op), func(t *testing.T) {
|
||||||
|
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
|
||||||
|
assert.Equal(t, tt.want, pm.Type)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAuthConfig_HashSecrets covers argon2id hashing of password and PIN
// secrets: enabled secrets are hashed (and verifiable), while disabled,
// empty or absent auth configs are left untouched.
func TestAuthConfig_HashSecrets(t *testing.T) {
	tests := []struct {
		name     string
		config   *AuthConfig
		wantErr  bool
		// validate inspects the config after HashSecrets ran successfully.
		validate func(*testing.T, *AuthConfig)
	}{
		{
			name: "hash password successfully",
			config: &AuthConfig{
				PasswordAuth: &PasswordAuthConfig{
					Enabled:  true,
					Password: "testPassword123",
				},
			},
			wantErr: false,
			validate: func(t *testing.T, config *AuthConfig) {
				if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
					t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
				}
				// Verify the hash can be verified
				if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
					t.Errorf("Hash verification failed: %v", err)
				}
			},
		},
		{
			name: "hash PIN successfully",
			config: &AuthConfig{
				PinAuth: &PINAuthConfig{
					Enabled: true,
					Pin:     "123456",
				},
			},
			wantErr: false,
			validate: func(t *testing.T, config *AuthConfig) {
				if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
					t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
				}
				// Verify the hash can be verified
				if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
					t.Errorf("Hash verification failed: %v", err)
				}
			},
		},
		{
			name: "hash both password and PIN",
			config: &AuthConfig{
				PasswordAuth: &PasswordAuthConfig{
					Enabled:  true,
					Password: "password",
				},
				PinAuth: &PINAuthConfig{
					Enabled: true,
					Pin:     "9999",
				},
			},
			wantErr: false,
			validate: func(t *testing.T, config *AuthConfig) {
				if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
					t.Errorf("Password not hashed with argon2id")
				}
				if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
					t.Errorf("PIN not hashed with argon2id")
				}
				if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
					t.Errorf("Password hash verification failed: %v", err)
				}
				if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
					t.Errorf("PIN hash verification failed: %v", err)
				}
			},
		},
		{
			// Disabled auth methods keep their plaintext untouched.
			name: "skip disabled password auth",
			config: &AuthConfig{
				PasswordAuth: &PasswordAuthConfig{
					Enabled:  false,
					Password: "password",
				},
			},
			wantErr: false,
			validate: func(t *testing.T, config *AuthConfig) {
				if config.PasswordAuth.Password != "password" {
					t.Errorf("Disabled password auth should not be hashed")
				}
			},
		},
		{
			// An enabled method with an empty secret is a no-op, not an error.
			name: "skip empty password",
			config: &AuthConfig{
				PasswordAuth: &PasswordAuthConfig{
					Enabled:  true,
					Password: "",
				},
			},
			wantErr: false,
			validate: func(t *testing.T, config *AuthConfig) {
				if config.PasswordAuth.Password != "" {
					t.Errorf("Empty password should remain empty")
				}
			},
		},
		{
			// A nil method must be skipped without affecting the others.
			name: "skip nil password auth",
			config: &AuthConfig{
				PasswordAuth: nil,
				PinAuth: &PINAuthConfig{
					Enabled: true,
					Pin:     "1234",
				},
			},
			wantErr: false,
			validate: func(t *testing.T, config *AuthConfig) {
				if config.PasswordAuth != nil {
					t.Errorf("PasswordAuth should remain nil")
				}
				if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
					t.Errorf("PIN should still be hashed")
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := tt.config.HashSecrets()
			if (err != nil) != tt.wantErr {
				t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.validate != nil {
				tt.validate(t, tt.config)
			}
		})
	}
}
|
||||||
|
|
||||||
|
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
|
||||||
|
config := &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "correctPassword",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.HashSecrets(); err != nil {
|
||||||
|
t.Fatalf("HashSecrets() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify with wrong password should fail
|
||||||
|
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
|
||||||
|
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||||
|
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_ClearSecrets(t *testing.T) {
|
||||||
|
config := &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "hashedPassword",
|
||||||
|
},
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "hashedPin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config.ClearSecrets()
|
||||||
|
|
||||||
|
if config.PasswordAuth.Password != "" {
|
||||||
|
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
|
||||||
|
}
|
||||||
|
if config.PinAuth.Pin != "" {
|
||||||
|
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package sessionkey
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeyPair holds an ed25519 key pair with both halves base64 (std alphabet)
// encoded, as produced by GenerateKeyPair.
type KeyPair struct {
	// PrivateKey is the base64-encoded ed25519 private key (signing half).
	PrivateKey string
	// PublicKey is the base64-encoded ed25519 public key (verification half).
	PublicKey string
}
|
||||||
|
|
||||||
|
// Claims is the JWT claim set for proxy session tokens: the standard
// registered claims plus the authentication method that established the
// session.
type Claims struct {
	jwt.RegisteredClaims
	// Method records which authentication mechanism was used to establish
	// the session.
	Method auth.Method `json:"method"`
}
|
||||||
|
|
||||||
|
func GenerateKeyPair() (*KeyPair, error) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate ed25519 key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KeyPair{
|
||||||
|
PrivateKey: base64.StdEncoding.EncodeToString(priv),
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
||||||
|
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decode private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(privKeyBytes) != ed25519.PrivateKeySize {
|
||||||
|
return "", fmt.Errorf("invalid private key size: got %d, want %d", len(privKeyBytes), ed25519.PrivateKeySize)
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey := ed25519.PrivateKey(privKeyBytes)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
claims := Claims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: auth.SessionJWTIssuer,
|
||||||
|
Subject: userID,
|
||||||
|
Audience: jwt.ClaimStrings{domain},
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
},
|
||||||
|
Method: method,
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||||
|
signedToken, err := token.SignedString(privKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("sign token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return signedToken, nil
|
||||||
|
}
|
||||||
@@ -21,6 +21,8 @@ import (
|
|||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
@@ -92,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -120,11 +122,13 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
||||||
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
||||||
}
|
}
|
||||||
|
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
|
||||||
|
s.proxyAuthClose = proxyAuthClose
|
||||||
gRPCOpts := []grpc.ServerOption{
|
gRPCOpts := []grpc.ServerOption{
|
||||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||||
grpc.KeepaliveParams(kasp),
|
grpc.KeepaliveParams(kasp),
|
||||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
|
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
|
||||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||||
@@ -150,10 +154,53 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
}
|
}
|
||||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||||
|
|
||||||
|
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
|
||||||
|
log.Info("ProxyService registered on gRPC server")
|
||||||
|
|
||||||
return gRPCAPIHandler
|
return gRPCAPIHandler
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReverseProxyGRPCServer returns the ProxyService gRPC server, built via the
// server's Create factory helper from the access-log manager, one-time token
// store, OIDC config, peers manager and users manager.
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
	return Create(s, func() *nbgrpc.ProxyServiceServer {
		proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
		// NOTE(review): the proxy manager is wired in after initialization,
		// presumably to avoid a construction-order cycle between the manager
		// and this server — confirm against ReverseProxyManager's deps.
		s.AfterInit(func(s *BaseServer) {
			proxyService.SetProxyManager(s.ReverseProxyManager())
		})
		return proxyService
	})
}
|
||||||
|
|
||||||
|
// proxyOIDCConfig assembles the OIDC settings handed to the proxy gRPC
// service, reusing the management server's HTTP auth configuration (issuer,
// client ID, audience, keys location, callback URL).
func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
	return Create(s, func() nbgrpc.ProxyOIDCConfig {
		return nbgrpc.ProxyOIDCConfig{
			Issuer: s.Config.HttpConfig.AuthIssuer,
			// todo: double check auth clientID value
			ClientID:    s.Config.HttpConfig.AuthClientID, // Reuse dashboard client
			Scopes:      []string{"openid", "profile", "email"},
			CallbackURL: s.Config.HttpConfig.AuthCallbackURL,
			HMACKey:     []byte(s.Config.DataStoreEncryptionKey), // Use the datastore encryption key for OIDC state HMACs, this should ensure all management instances are using the same key.
			Audience:    s.Config.HttpConfig.AuthAudience,
			KeysLocation: s.Config.HttpConfig.AuthKeysLocation,
		}
	})
}
|
||||||
|
|
||||||
|
// ProxyTokenStore returns the shared one-time token store used for proxy
// authentication; tokens are created with a one-minute lifetime.
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
	return Create(s, func() *nbgrpc.OneTimeTokenStore {
		tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
		log.Info("One-time token store initialized for proxy authentication")
		return tokenStore
	})
}
|
||||||
|
|
||||||
|
// AccessLogsManager returns the reverse-proxy access-log manager, constructed
// from the server store, the permissions manager and the geo-location manager.
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
	return Create(s, func() accesslogs.Manager {
		accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
		return accessLogManager
	})
}
|
||||||
|
|
||||||
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
||||||
// Load server's certificate and private key
|
// Load server's certificate and private key
|
||||||
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
||||||
|
|||||||
@@ -100,6 +100,8 @@ type HttpServerConfig struct {
|
|||||||
CertFile string
|
CertFile string
|
||||||
// CertKey is the location of the certificate private key
|
// CertKey is the location of the certificate private key
|
||||||
CertKey string
|
CertKey string
|
||||||
|
// AuthClientID is the client id used for proxy SSO auth
|
||||||
|
AuthClientID string
|
||||||
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
||||||
AuthAudience string
|
AuthAudience string
|
||||||
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
|
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
|
||||||
@@ -117,6 +119,8 @@ type HttpServerConfig struct {
|
|||||||
IdpSignKeyRefreshEnabled bool
|
IdpSignKeyRefreshEnabled bool
|
||||||
// Extra audience
|
// Extra audience
|
||||||
ExtraAuthAudience string
|
ExtraAuthAudience string
|
||||||
|
// AuthCallbackDomain contains the callback domain
|
||||||
|
AuthCallbackURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Host represents a Netbird host (e.g. STUN, TURN, Signal)
|
// Host represents a Netbird host (e.g. STUN, TURN, Signal)
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
@@ -98,6 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create account manager: %v", err)
|
log.Fatalf("failed to create account manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.AfterInit(func(s *BaseServer) {
|
||||||
|
accountManager.SetServiceManager(s.ReverseProxyManager())
|
||||||
|
})
|
||||||
|
|
||||||
return accountManager
|
return accountManager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -154,7 +162,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||||
return Create(s, func() resources.Manager {
|
return Create(s, func() resources.Manager {
|
||||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
|
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,3 +189,16 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
|||||||
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
|
||||||
|
return Create(s, func() reverseproxy.Manager {
|
||||||
|
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||||
|
return Create(s, func() *manager.Manager {
|
||||||
|
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
|
||||||
|
return &m
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,10 +18,9 @@ import (
|
|||||||
"golang.org/x/net/http2/h2c"
|
"golang.org/x/net/http2/h2c"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/metrics"
|
"github.com/netbirdio/netbird/management/server/metrics"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/util/wsproxy"
|
"github.com/netbirdio/netbird/util/wsproxy"
|
||||||
@@ -59,6 +58,8 @@ type BaseServer struct {
|
|||||||
mgmtMetricsPort int
|
mgmtMetricsPort int
|
||||||
mgmtPort int
|
mgmtPort int
|
||||||
|
|
||||||
|
proxyAuthClose func()
|
||||||
|
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
certManager *autocert.Manager
|
certManager *autocert.Manager
|
||||||
update *version.Update
|
update *version.Update
|
||||||
@@ -139,8 +140,11 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
go metricsWorker.Run(srvCtx)
|
go metricsWorker.Run(srvCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run afterInit hooks before starting any servers
|
// Eagerly create the gRPC server so that all AfterInit hooks are registered
|
||||||
// This allows registering additional gRPC services (e.g., Signal) before Serve() is called
|
// before we iterate them. Lazy creation after the loop would miss hooks
|
||||||
|
// registered during GRPCServer() construction (e.g., SetProxyManager).
|
||||||
|
s.GRPCServer()
|
||||||
|
|
||||||
for _, fn := range s.afterInit {
|
for _, fn := range s.afterInit {
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
fn(s)
|
fn(s)
|
||||||
@@ -218,6 +222,11 @@ func (s *BaseServer) Stop() error {
|
|||||||
_ = s.certManager.Listener().Close()
|
_ = s.certManager.Listener().Close()
|
||||||
}
|
}
|
||||||
s.GRPCServer().Stop()
|
s.GRPCServer().Stop()
|
||||||
|
s.ReverseProxyGRPCServer().Close()
|
||||||
|
if s.proxyAuthClose != nil {
|
||||||
|
s.proxyAuthClose()
|
||||||
|
s.proxyAuthClose = nil
|
||||||
|
}
|
||||||
_ = s.Store().Close(ctx)
|
_ = s.Store().Close(ctx)
|
||||||
_ = s.EventStore().Close(ctx)
|
_ = s.EventStore().Close(ctx)
|
||||||
if s.update != nil {
|
if s.update != nil {
|
||||||
|
|||||||
167
management/internals/shared/grpc/onetime_token.go
Normal file
167
management/internals/shared/grpc/onetime_token.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
||||||
|
// for proxy-to-management RPC authentication. Tokens are generated when
|
||||||
|
// a service is created and must be used exactly once by the proxy
|
||||||
|
// to authenticate a subsequent RPC call.
|
||||||
|
type OneTimeTokenStore struct {
|
||||||
|
tokens map[string]*tokenMetadata
|
||||||
|
mu sync.RWMutex
|
||||||
|
cleanup *time.Ticker
|
||||||
|
cleanupDone chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenMetadata stores information about a one-time token
|
||||||
|
type tokenMetadata struct {
|
||||||
|
ServiceID string
|
||||||
|
AccountID string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
||||||
|
// of expired tokens. The cleanupInterval determines how often expired
|
||||||
|
// tokens are removed from memory.
|
||||||
|
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||||
|
store := &OneTimeTokenStore{
|
||||||
|
tokens: make(map[string]*tokenMetadata),
|
||||||
|
cleanup: time.NewTicker(cleanupInterval),
|
||||||
|
cleanupDone: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start background cleanup goroutine
|
||||||
|
go store.cleanupExpired()
|
||||||
|
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateToken creates a new cryptographically secure one-time token
|
||||||
|
// with the specified TTL. The token is associated with a specific
|
||||||
|
// accountID and serviceID for validation purposes.
|
||||||
|
//
|
||||||
|
// Returns the generated token string or an error if random generation fails.
|
||||||
|
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
||||||
|
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||||
|
randomBytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(randomBytes); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode as URL-safe base64 for easy transmission in gRPC
|
||||||
|
token := base64.URLEncoding.EncodeToString(randomBytes)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.tokens[token] = &tokenMetadata{
|
||||||
|
ServiceID: serviceID,
|
||||||
|
AccountID: accountID,
|
||||||
|
ExpiresAt: time.Now().Add(ttl),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
||||||
|
serviceID, accountID, ttl)
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndConsume verifies the token against the provided accountID and
|
||||||
|
// serviceID, checks expiration, and then deletes it to enforce single-use.
|
||||||
|
//
|
||||||
|
// This method uses constant-time comparison to prevent timing attacks.
|
||||||
|
//
|
||||||
|
// Returns nil on success, or an error if:
|
||||||
|
// - Token doesn't exist
|
||||||
|
// - Token has expired
|
||||||
|
// - Account ID doesn't match
|
||||||
|
// - Reverse proxy ID doesn't match
|
||||||
|
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
metadata, exists := s.tokens[token]
|
||||||
|
if !exists {
|
||||||
|
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
||||||
|
serviceID, accountID)
|
||||||
|
return fmt.Errorf("invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expiration
|
||||||
|
if time.Now().After(metadata.ExpiresAt) {
|
||||||
|
delete(s.tokens, token)
|
||||||
|
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
||||||
|
serviceID, accountID)
|
||||||
|
return fmt.Errorf("token expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate account ID using constant-time comparison (prevents timing attacks)
|
||||||
|
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
||||||
|
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
||||||
|
metadata.AccountID, accountID)
|
||||||
|
return fmt.Errorf("account ID mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate service ID using constant-time comparison
|
||||||
|
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
||||||
|
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
||||||
|
metadata.ServiceID, serviceID)
|
||||||
|
return fmt.Errorf("service ID mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete token immediately to enforce single-use
|
||||||
|
delete(s.tokens, token)
|
||||||
|
|
||||||
|
log.Infof("Token validated and consumed for proxy %s in account %s",
|
||||||
|
serviceID, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
||||||
|
func (s *OneTimeTokenStore) cleanupExpired() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.cleanup.C:
|
||||||
|
s.mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
removed := 0
|
||||||
|
for token, metadata := range s.tokens {
|
||||||
|
if now.After(metadata.ExpiresAt) {
|
||||||
|
delete(s.tokens, token)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if removed > 0 {
|
||||||
|
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
case <-s.cleanupDone:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup goroutine and releases resources
|
||||||
|
func (s *OneTimeTokenStore) Close() {
|
||||||
|
s.cleanup.Stop()
|
||||||
|
close(s.cleanupDone)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
||||||
|
func (s *OneTimeTokenStore) GetTokenCount() int {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return len(s.tokens)
|
||||||
|
}
|
||||||
1083
management/internals/shared/grpc/proxy.go
Normal file
1083
management/internals/shared/grpc/proxy.go
Normal file
File diff suppressed because it is too large
Load Diff
234
management/internals/shared/grpc/proxy_auth.go
Normal file
234
management/internals/shared/grpc/proxy_auth.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// lastUsedUpdateInterval is the minimum interval between last_used updates for the same token.
|
||||||
|
lastUsedUpdateInterval = time.Minute
|
||||||
|
// lastUsedCleanupInterval is how often stale lastUsed entries are removed.
|
||||||
|
lastUsedCleanupInterval = 2 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
type proxyTokenContextKey struct{}
|
||||||
|
|
||||||
|
// ProxyTokenContextKey is the typed key used to store validated token info in context.
|
||||||
|
var ProxyTokenContextKey = proxyTokenContextKey{}
|
||||||
|
|
||||||
|
// proxyTokenID identifies a proxy access token by its database ID.
|
||||||
|
type proxyTokenID = string
|
||||||
|
|
||||||
|
// proxyTokenStore defines the store interface needed for token validation
|
||||||
|
type proxyTokenStore interface {
|
||||||
|
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||||
|
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyAuthInterceptor holds state for proxy authentication interceptors.
|
||||||
|
type proxyAuthInterceptor struct {
|
||||||
|
store proxyTokenStore
|
||||||
|
failureLimiter *authFailureLimiter
|
||||||
|
|
||||||
|
// lastUsedMu protects lastUsedTimes
|
||||||
|
lastUsedMu sync.Mutex
|
||||||
|
lastUsedTimes map[proxyTokenID]time.Time
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
i := &proxyAuthInterceptor{
|
||||||
|
store: tokenStore,
|
||||||
|
failureLimiter: newAuthFailureLimiter(),
|
||||||
|
lastUsedTimes: make(map[proxyTokenID]time.Time),
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
go i.lastUsedCleanupLoop(ctx)
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens.
|
||||||
|
// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting.
|
||||||
|
// The returned close function must be called on shutdown to stop background goroutines.
|
||||||
|
func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) {
|
||||||
|
interceptor := newProxyAuthInterceptor(tokenStore)
|
||||||
|
|
||||||
|
unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||||
|
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := interceptor.validateProxyToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||||
|
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
|
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
||||||
|
return handler(srv, ss)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := interceptor.validateProxyToken(ss.Context())
|
||||||
|
if err != nil {
|
||||||
|
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||||
|
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
|
||||||
|
wrapped := &wrappedServerStream{
|
||||||
|
ServerStream: ss,
|
||||||
|
ctx: ctx,
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler(srv, wrapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
return unary, stream, interceptor.close
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||||
|
clientIP := peerIPFromContext(ctx)
|
||||||
|
|
||||||
|
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
||||||
|
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := i.doValidateProxyToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if clientIP != "" {
|
||||||
|
i.failureLimiter.recordFailure(clientIP)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
i.maybeUpdateLastUsed(ctx, token.ID)
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||||
|
md, ok := metadata.FromIncomingContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
authValues := md.Get("authorization")
|
||||||
|
if len(authValues) == 0 {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
authValue := authValues[0]
|
||||||
|
if !strings.HasPrefix(authValue, "Bearer ") {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
|
||||||
|
}
|
||||||
|
|
||||||
|
plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer "))
|
||||||
|
|
||||||
|
if err := plainToken.Validate(); err != nil {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "invalid token format")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash())
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
||||||
|
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
||||||
|
|
||||||
|
if !token.IsValid() {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update.
|
||||||
|
func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
i.lastUsedMu.Lock()
|
||||||
|
lastUpdate, exists := i.lastUsedTimes[tokenID]
|
||||||
|
if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval {
|
||||||
|
i.lastUsedMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
i.lastUsedTimes[tokenID] = now
|
||||||
|
i.lastUsedMu.Unlock()
|
||||||
|
|
||||||
|
if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(lastUsedCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
i.cleanupStaleLastUsed()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupStaleLastUsed removes entries older than 2x the update interval.
|
||||||
|
func (i *proxyAuthInterceptor) cleanupStaleLastUsed() {
|
||||||
|
i.lastUsedMu.Lock()
|
||||||
|
defer i.lastUsedMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
staleThreshold := 2 * lastUsedUpdateInterval
|
||||||
|
for id, lastUpdate := range i.lastUsedTimes {
|
||||||
|
if now.Sub(lastUpdate) > staleThreshold {
|
||||||
|
delete(i.lastUsedTimes, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) close() {
|
||||||
|
i.cancel()
|
||||||
|
i.failureLimiter.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyTokenFromContext retrieves the validated proxy token from the context
|
||||||
|
func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken {
|
||||||
|
token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrappedServerStream wraps a grpc.ServerStream to provide a custom context
|
||||||
|
type wrappedServerStream struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wrappedServerStream) Context() context.Context {
|
||||||
|
return w.ctx
|
||||||
|
}
|
||||||
134
management/internals/shared/grpc/proxy_auth_ratelimit.go
Normal file
134
management/internals/shared/grpc/proxy_auth_ratelimit.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in.
|
||||||
|
proxyAuthFailureBurst = 5
|
||||||
|
// proxyAuthLimiterCleanup is how often stale limiters are removed.
|
||||||
|
proxyAuthLimiterCleanup = 5 * time.Minute
|
||||||
|
// proxyAuthLimiterTTL is how long a limiter is kept after the last failure.
|
||||||
|
proxyAuthLimiterTTL = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts.
|
||||||
|
// One token every 12 seconds = 5 per minute.
|
||||||
|
var defaultProxyAuthFailureRate = rate.Every(12 * time.Second)
|
||||||
|
|
||||||
|
// clientIP identifies a client by its IP address for rate limiting purposes.
|
||||||
|
type clientIP = string
|
||||||
|
|
||||||
|
type limiterEntry struct {
|
||||||
|
limiter *rate.Limiter
|
||||||
|
lastAccess time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts.
|
||||||
|
type authFailureLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
limiters map[clientIP]*limiterEntry
|
||||||
|
failureRate rate.Limit
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthFailureLimiter() *authFailureLimiter {
|
||||||
|
return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
l := &authFailureLimiter{
|
||||||
|
limiters: make(map[clientIP]*limiterEntry),
|
||||||
|
failureRate: failureRate,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
go l.cleanupLoop(ctx)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// isLimited returns true if the given IP has exhausted its failure budget.
|
||||||
|
func (l *authFailureLimiter) isLimited(ip clientIP) bool {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
entry, exists := l.limiters[ip]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.limiter.Tokens() < 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordFailure consumes a token from the rate limiter for the given IP.
|
||||||
|
func (l *authFailureLimiter) recordFailure(ip clientIP) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
entry, exists := l.limiters[ip]
|
||||||
|
if !exists {
|
||||||
|
entry = &limiterEntry{
|
||||||
|
limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst),
|
||||||
|
}
|
||||||
|
l.limiters[ip] = entry
|
||||||
|
}
|
||||||
|
entry.lastAccess = now
|
||||||
|
entry.limiter.Allow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authFailureLimiter) cleanupLoop(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(proxyAuthLimiterCleanup)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
l.cleanup()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authFailureLimiter) cleanup() {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for ip, entry := range l.limiters {
|
||||||
|
if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL {
|
||||||
|
delete(l.limiters, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authFailureLimiter) stop() {
|
||||||
|
l.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// peerIPFromContext extracts the client IP from the gRPC context.
|
||||||
|
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
||||||
|
func peerIPFromContext(ctx context.Context) clientIP {
|
||||||
|
if addr, ok := realip.FromContext(ctx); ok {
|
||||||
|
return addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if p, ok := peer.FromContext(ctx); ok {
|
||||||
|
host, _, err := net.SplitHostPort(p.Addr.String())
|
||||||
|
if err != nil {
|
||||||
|
return p.Addr.String()
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
ip := "192.168.1.1"
|
||||||
|
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||||
|
l.recordFailure(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||||
|
l.recordFailure("192.168.1.1")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, l.isLimited("192.168.1.1"))
|
||||||
|
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
ip := "10.0.0.1"
|
||||||
|
|
||||||
|
// Exhaust burst
|
||||||
|
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||||
|
l.recordFailure(ip)
|
||||||
|
}
|
||||||
|
require.True(t, l.isLimited(ip))
|
||||||
|
|
||||||
|
// Wait for token replenishment
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
l.recordFailure("10.0.0.1")
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
require.Len(t, l.limiters, 1)
|
||||||
|
// Backdate the entry so it looks stale
|
||||||
|
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
l.cleanup()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
l.recordFailure("10.0.0.1")
|
||||||
|
l.recordFailure("10.0.0.2")
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
// Only backdate one entry
|
||||||
|
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
l.cleanup()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
|
||||||
|
assert.Contains(t, l.limiters, "10.0.0.2")
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
381
management/internals/shared/grpc/proxy_group_access_test.go
Normal file
381
management/internals/shared/grpc/proxy_group_access_test.go
Normal file
@@ -0,0 +1,381 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockReverseProxyManager is a test double for the reverse proxy manager
// dependency of ProxyServiceServer. Only GetAccountServices has configurable
// behavior (via proxiesByAccount and err); every other method is an inert
// stub that satisfies the interface.
type mockReverseProxyManager struct {
	proxiesByAccount map[string][]*reverseproxy.Service // services returned per account ID
	err              error                              // forced error returned by GetAccountServices
}

// GetAccountServices returns the configured services for accountID, or the
// configured error if one is set. A missing account yields a nil slice.
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
	if m.err != nil {
		return nil, m.err
	}
	return m.proxiesByAccount[accountID], nil
}

// GetGlobalServices is an inert stub.
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
	return nil, nil
}

// GetAllServices is an inert stub returning an empty slice.
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
	return []*reverseproxy.Service{}, nil
}

// GetService is an inert stub returning an empty service.
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
	return &reverseproxy.Service{}, nil
}

// CreateService is an inert stub returning an empty service.
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
	return &reverseproxy.Service{}, nil
}

// UpdateService is an inert stub returning an empty service.
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
	return &reverseproxy.Service{}, nil
}

// DeleteService is an inert stub.
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
	return nil
}

// SetCertificateIssuedAt is an inert stub.
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
	return nil
}

// SetStatus is an inert stub.
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
	return nil
}

// ReloadAllServicesForAccount is an inert stub.
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
	return nil
}

// ReloadService is an inert stub.
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
	return nil
}

// GetServiceByID is an inert stub returning an empty service.
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
	return &reverseproxy.Service{}, nil
}

// GetServiceIDByTargetID is an inert stub.
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
	return "", nil
}
|
||||||
|
|
||||||
|
// mockUsersManager is a test double for the users manager dependency of
// ProxyServiceServer, returning canned users from an in-memory map.
type mockUsersManager struct {
	users map[string]*types.User // users returned by GetUser, keyed by user ID
	err   error                  // forced error returned by GetUser when set
}

// GetUser returns the configured user for userID, the forced error if one is
// set, or a "user not found" error when the ID is unknown.
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
	if m.err != nil {
		return nil, m.err
	}
	user, ok := m.users[userID]
	if !ok {
		return nil, errors.New("user not found")
	}
	return user, nil
}
|
||||||
|
|
||||||
|
// TestValidateUserGroupAccess covers the access decision for a user hitting a
// proxied domain: the proxy must exist in the user's own account, and when
// bearer auth is enabled with distribution groups configured, the user must
// belong to at least one of those groups.
func TestValidateUserGroupAccess(t *testing.T) {
	tests := []struct {
		name             string
		domain           string
		userID           string
		proxiesByAccount map[string][]*reverseproxy.Service
		users            map[string]*types.User
		proxyErr         error
		userErr          error
		expectErr        bool
		expectErrMsg     string // substring expected in the returned error
	}{
		{
			name:   "user not found",
			domain: "app.example.com",
			userID: "unknown-user",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {{Domain: "app.example.com", AccountID: "account1"}},
			},
			users:        map[string]*types.User{},
			expectErr:    true,
			expectErrMsg: "user not found",
		},
		{
			name:             "proxy not found in user's account",
			domain:           "app.example.com",
			userID:           "user1",
			proxiesByAccount: map[string][]*reverseproxy.Service{},
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1"},
			},
			expectErr:    true,
			expectErrMsg: "service not found",
		},
		{
			// Lookup is scoped to the user's account, so a matching domain in
			// another account must not be reachable.
			name:   "proxy exists in different account - not accessible",
			domain: "app.example.com",
			userID: "user1",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account2": {{Domain: "app.example.com", AccountID: "account2"}},
			},
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1"},
			},
			expectErr:    true,
			expectErrMsg: "service not found",
		},
		{
			name:   "no bearer auth configured - same account allows access",
			domain: "app.example.com",
			userID: "user1",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
			},
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1"},
			},
			expectErr: false,
		},
		{
			name:   "bearer auth disabled - same account allows access",
			domain: "app.example.com",
			userID: "user1",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {{
					Domain:    "app.example.com",
					AccountID: "account1",
					Auth: reverseproxy.AuthConfig{
						BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
					},
				}},
			},
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1"},
			},
			expectErr: false,
		},
		{
			// An empty DistributionGroups list means no group restriction.
			name:   "bearer auth enabled but no groups configured - same account allows access",
			domain: "app.example.com",
			userID: "user1",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {{
					Domain:    "app.example.com",
					AccountID: "account1",
					Auth: reverseproxy.AuthConfig{
						BearerAuth: &reverseproxy.BearerAuthConfig{
							Enabled:            true,
							DistributionGroups: []string{},
						},
					},
				}},
			},
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1"},
			},
			expectErr: false,
		},
		{
			name:   "user not in allowed groups",
			domain: "app.example.com",
			userID: "user1",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {{
					Domain:    "app.example.com",
					AccountID: "account1",
					Auth: reverseproxy.AuthConfig{
						BearerAuth: &reverseproxy.BearerAuthConfig{
							Enabled:            true,
							DistributionGroups: []string{"group1", "group2"},
						},
					},
				}},
			},
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
			},
			expectErr:    true,
			expectErrMsg: "not in allowed groups",
		},
		{
			// Membership in any single distribution group is sufficient.
			name:   "user in one of the allowed groups - allow access",
			domain: "app.example.com",
			userID: "user1",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {{
					Domain:    "app.example.com",
					AccountID: "account1",
					Auth: reverseproxy.AuthConfig{
						BearerAuth: &reverseproxy.BearerAuthConfig{
							Enabled:            true,
							DistributionGroups: []string{"group1", "group2"},
						},
					},
				}},
			},
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
			},
			expectErr: false,
		},
		{
			name:   "user in all allowed groups - allow access",
			domain: "app.example.com",
			userID: "user1",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {{
					Domain:    "app.example.com",
					AccountID: "account1",
					Auth: reverseproxy.AuthConfig{
						BearerAuth: &reverseproxy.BearerAuthConfig{
							Enabled:            true,
							DistributionGroups: []string{"group1", "group2"},
						},
					},
				}},
			},
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
			},
			expectErr: false,
		},
		{
			name:             "proxy manager error",
			domain:           "app.example.com",
			userID:           "user1",
			proxiesByAccount: nil,
			proxyErr:         errors.New("database error"),
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1"},
			},
			expectErr:    true,
			expectErrMsg: "get account services",
		},
		{
			name:   "multiple proxies in account - finds correct one",
			domain: "app2.example.com",
			userID: "user1",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {
					{Domain: "app1.example.com", AccountID: "account1"},
					{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
					{Domain: "app3.example.com", AccountID: "account1"},
				},
			},
			users: map[string]*types.User{
				"user1": {Id: "user1", AccountID: "account1"},
			},
			expectErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			server := &ProxyServiceServer{
				reverseProxyManager: &mockReverseProxyManager{
					proxiesByAccount: tt.proxiesByAccount,
					err:              tt.proxyErr,
				},
				usersManager: &mockUsersManager{
					users: tt.users,
					err:   tt.userErr,
				},
			}

			err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)

			if tt.expectErr {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.expectErrMsg)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestGetAccountProxyByDomain exercises the per-account domain lookup helper:
// it must return the matching service from the given account, and an error
// when the domain is absent or the underlying manager fails.
func TestGetAccountProxyByDomain(t *testing.T) {
	tests := []struct {
		name             string
		accountID        string
		domain           string
		proxiesByAccount map[string][]*reverseproxy.Service
		err              error // forced manager error
		expectProxy      bool
		expectErr        bool
	}{
		{
			name:      "proxy found",
			accountID: "account1",
			domain:    "app.example.com",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {
					{Domain: "other.example.com", AccountID: "account1"},
					{Domain: "app.example.com", AccountID: "account1"},
				},
			},
			expectProxy: true,
			expectErr:   false,
		},
		{
			name:      "proxy not found in account",
			accountID: "account1",
			domain:    "unknown.example.com",
			proxiesByAccount: map[string][]*reverseproxy.Service{
				"account1": {{Domain: "app.example.com", AccountID: "account1"}},
			},
			expectProxy: false,
			expectErr:   true,
		},
		{
			name:             "empty proxy list for account",
			accountID:        "account1",
			domain:           "app.example.com",
			proxiesByAccount: map[string][]*reverseproxy.Service{},
			expectProxy:      false,
			expectErr:        true,
		},
		{
			name:             "manager error",
			accountID:        "account1",
			domain:           "app.example.com",
			proxiesByAccount: nil,
			err:              errors.New("database error"),
			expectProxy:      false,
			expectErr:        true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			server := &ProxyServiceServer{
				reverseProxyManager: &mockReverseProxyManager{
					proxiesByAccount: tt.proxiesByAccount,
					err:              tt.err,
				},
			}

			proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)

			if tt.expectErr {
				require.Error(t, err)
				assert.Nil(t, proxy)
			} else {
				require.NoError(t, err)
				require.NotNil(t, proxy)
				assert.Equal(t, tt.domain, proxy.Domain)
			}
		})
	}
}
|
||||||
232
management/internals/shared/grpc/proxy_test.go
Normal file
232
management/internals/shared/grpc/proxy_test.go
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||||
|
// and returns the channel where messages will be received.
|
||||||
|
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
||||||
|
ch := make(chan *proto.ProxyMapping, 10)
|
||||||
|
conn := &proxyConnection{
|
||||||
|
proxyID: proxyID,
|
||||||
|
address: clusterAddr,
|
||||||
|
sendChan: ch,
|
||||||
|
}
|
||||||
|
s.connectedProxies.Store(proxyID, conn)
|
||||||
|
|
||||||
|
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||||
|
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||||
|
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
||||||
|
select {
|
||||||
|
case msg := <-ch:
|
||||||
|
return msg
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSendServiceUpdateToCluster_UniqueTokensPerProxy verifies that a CREATE
// update fanned out to one cluster gives every connected proxy its own
// one-time auth token, and that each token validates independently.
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
	tokenStore := NewOneTimeTokenStore(time.Hour)
	defer tokenStore.Close()

	s := &ProxyServiceServer{
		tokenStore:  tokenStore,
		updatesChan: make(chan *proto.ProxyMapping, 100),
	}

	const cluster = "proxy.example.com"
	const numProxies = 3

	// Register three fake proxies in the same cluster.
	channels := make([]chan *proto.ProxyMapping, numProxies)
	for i := range numProxies {
		id := "proxy-" + string(rune('a'+i))
		channels[i] = registerFakeProxy(s, id, cluster)
	}

	update := &proto.ProxyMapping{
		Type:      proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
		Id:        "service-1",
		AccountId: "account-1",
		Domain:    "test.example.com",
		Path: []*proto.PathMapping{
			{Path: "/", Target: "http://10.0.0.1:8080/"},
		},
	}

	s.SendServiceUpdateToCluster(update, cluster)

	// Every proxy must receive the update, each carrying a token.
	tokens := make([]string, numProxies)
	for i, ch := range channels {
		msg := drainChannel(ch)
		require.NotNil(t, msg, "proxy %d should receive a message", i)
		assert.Equal(t, update.Domain, msg.Domain)
		assert.Equal(t, update.Id, msg.Id)
		assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
		tokens[i] = msg.AuthToken
	}

	// All tokens must be unique
	tokenSet := make(map[string]struct{})
	for i, tok := range tokens {
		_, exists := tokenSet[tok]
		assert.False(t, exists, "proxy %d got duplicate token", i)
		tokenSet[tok] = struct{}{}
	}

	// Each token must be independently consumable
	for i, tok := range tokens {
		err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
		assert.NoError(t, err, "proxy %d token should validate successfully", i)
	}
}
|
||||||
|
|
||||||
|
// TestSendServiceUpdateToCluster_DeleteNoToken verifies that REMOVED updates
// are delivered without auth tokens and create no entries in the token store.
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
	tokenStore := NewOneTimeTokenStore(time.Hour)
	defer tokenStore.Close()

	s := &ProxyServiceServer{
		tokenStore:  tokenStore,
		updatesChan: make(chan *proto.ProxyMapping, 100),
	}

	const cluster = "proxy.example.com"
	ch1 := registerFakeProxy(s, "proxy-a", cluster)
	ch2 := registerFakeProxy(s, "proxy-b", cluster)

	update := &proto.ProxyMapping{
		Type:      proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
		Id:        "service-1",
		AccountId: "account-1",
		Domain:    "test.example.com",
	}

	s.SendServiceUpdateToCluster(update, cluster)

	msg1 := drainChannel(ch1)
	msg2 := drainChannel(ch2)
	require.NotNil(t, msg1)
	require.NotNil(t, msg2)

	// Delete operations should not generate tokens
	assert.Empty(t, msg1.AuthToken)
	assert.Empty(t, msg2.AuthToken)

	// No tokens should have been created
	assert.Equal(t, 0, tokenStore.GetTokenCount())
}
|
||||||
|
|
||||||
|
// TestSendServiceUpdate_UniqueTokensPerProxy verifies that the global
// broadcast path also issues a distinct, valid token to each connected proxy,
// even across different clusters.
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
	tokenStore := NewOneTimeTokenStore(time.Hour)
	defer tokenStore.Close()

	s := &ProxyServiceServer{
		tokenStore:  tokenStore,
		updatesChan: make(chan *proto.ProxyMapping, 100),
	}

	// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
	ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
	ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")

	update := &proto.ProxyMapping{
		Type:      proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
		Id:        "service-1",
		AccountId: "account-1",
		Domain:    "test.example.com",
	}

	s.SendServiceUpdate(update)

	msg1 := drainChannel(ch1)
	msg2 := drainChannel(ch2)
	require.NotNil(t, msg1)
	require.NotNil(t, msg2)

	assert.NotEmpty(t, msg1.AuthToken)
	assert.NotEmpty(t, msg2.AuthToken)
	assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")

	// Both tokens should validate
	assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
	assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
}
|
||||||
|
|
||||||
|
// generateState creates a state using the same format as GetOIDCURL.
|
||||||
|
func generateState(s *ProxyServiceServer, redirectURL string) string {
|
||||||
|
nonce := make([]byte, 16)
|
||||||
|
_, _ = rand.Read(nonce)
|
||||||
|
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
|
||||||
|
|
||||||
|
payload := redirectURL + "|" + nonceB64
|
||||||
|
hmacSum := s.generateHMAC(payload)
|
||||||
|
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
oidcConfig: ProxyOIDCConfig{
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL := "https://app.example.com/callback"
|
||||||
|
|
||||||
|
// Generate 100 states for the same redirect URL
|
||||||
|
states := make(map[string]bool)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
state := generateState(s, redirectURL)
|
||||||
|
|
||||||
|
// State must have 3 parts: base64(url)|nonce|hmac
|
||||||
|
parts := strings.Split(state, "|")
|
||||||
|
require.Equal(t, 3, len(parts), "state must have 3 parts")
|
||||||
|
|
||||||
|
// State must be unique
|
||||||
|
require.False(t, states[state], "state %d is a duplicate", i)
|
||||||
|
states[state] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
oidcConfig: ProxyOIDCConfig{
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Old format had only 2 parts: base64(url)|hmac
|
||||||
|
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||||
|
|
||||||
|
_, _, err := s.ValidateState("base64url|hmac")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid state format")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
oidcConfig: ProxyOIDCConfig{
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store with tampered HMAC
|
||||||
|
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||||
|
|
||||||
|
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid state signature")
|
||||||
|
}
|
||||||
304
management/internals/shared/grpc/validate_session_test.go
Normal file
304
management/internals/shared/grpc/validate_session_test.go
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// validateSessionTestSetup bundles what a ValidateSession test needs: the
// gRPC proxy service under test, the backing store, and a cleanup function.
type validateSessionTestSetup struct {
	proxyService *ProxyServiceServer // service under test
	store        store.Store         // SQL-backed test store
	cleanup      func()              // releases the test store; callers defer this
}
|
||||||
|
|
||||||
|
// setupValidateSessionTest builds a ProxyServiceServer wired to a SQL-backed
// test store (seeded from auth_callback.sql) and pre-populates the store with
// the test proxies. The returned cleanup must be deferred by the caller.
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
	t.Helper()

	ctx := context.Background()
	testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
	require.NoError(t, err)

	// Both managers read straight from the shared test store.
	proxyManager := &testValidateSessionProxyManager{store: testStore}
	usersManager := &testValidateSessionUsersManager{store: testStore}

	proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
	proxyService.SetProxyManager(proxyManager)

	// Seed the proxies the individual tests look up by domain.
	createTestProxies(t, ctx, testStore)

	return &validateSessionTestSetup{
		proxyService: proxyService,
		store:        testStore,
		cleanup:      storeCleanup,
	}
}
|
||||||
|
|
||||||
|
// createTestProxies seeds the store with two bearer-auth-enabled services
// sharing one session key pair: an unrestricted one ("test-proxy.example.com")
// and one gated by the "allowedGroupId" distribution group
// ("restricted-proxy.example.com").
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
	t.Helper()

	pubKey, privKey := generateSessionKeyPair(t)

	// Proxy with bearer auth but no group restriction.
	testProxy := &reverseproxy.Service{
		ID:                "testProxyId",
		AccountID:         "testAccountId",
		Name:              "Test Proxy",
		Domain:            "test-proxy.example.com",
		Enabled:           true,
		SessionPrivateKey: privKey,
		SessionPublicKey:  pubKey,
		Auth: reverseproxy.AuthConfig{
			BearerAuth: &reverseproxy.BearerAuthConfig{
				Enabled: true,
			},
		},
	}
	require.NoError(t, testStore.CreateService(ctx, testProxy))

	// Proxy restricted to members of "allowedGroupId".
	restrictedProxy := &reverseproxy.Service{
		ID:                "restrictedProxyId",
		AccountID:         "testAccountId",
		Name:              "Restricted Proxy",
		Domain:            "restricted-proxy.example.com",
		Enabled:           true,
		SessionPrivateKey: privKey,
		SessionPublicKey:  pubKey,
		Auth: reverseproxy.AuthConfig{
			BearerAuth: &reverseproxy.BearerAuthConfig{
				Enabled:            true,
				DistributionGroups: []string{"allowedGroupId"},
			},
		},
	}
	require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
}
|
||||||
|
|
||||||
|
func generateSessionKeyPair(t *testing.T) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
||||||
|
t.Helper()
|
||||||
|
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_UserAllowed(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, resp.Valid, "User should be allowed access")
|
||||||
|
assert.Equal(t, "allowedUserId", resp.UserId)
|
||||||
|
assert.Empty(t, resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "restricted-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "User not in group should be denied")
|
||||||
|
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
||||||
|
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "User in different account should be denied")
|
||||||
|
assert.Equal(t, "account_mismatch", resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_UserNotFound(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "Non-existent user should be denied")
|
||||||
|
assert.Equal(t, "user_not_found", resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_ProxyNotFound(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "unknown-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "Unknown proxy should be denied")
|
||||||
|
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_InvalidToken(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
SessionToken: "invalid-token",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "Invalid token should be denied")
|
||||||
|
assert.Equal(t, "invalid_token", resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_MissingDomain(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
SessionToken: "some-token",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid)
|
||||||
|
assert.Contains(t, resp.DeniedReason, "missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_MissingToken(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid)
|
||||||
|
assert.Contains(t, resp.DeniedReason, "missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
// testValidateSessionProxyManager is a test double for the reverse proxy
// manager. Read operations delegate to the backing store without locking;
// all other operations are no-op stubs returning zero values.
type testValidateSessionProxyManager struct {
	store store.Store
}

// GetAllServices is a no-op stub.
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
	return nil, nil
}

// GetService is a no-op stub.
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
	return nil, nil
}

// CreateService is a no-op stub.
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
	return nil, nil
}

// UpdateService is a no-op stub.
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
	return nil, nil
}

// DeleteService is a no-op stub.
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
	return nil
}

// SetCertificateIssuedAt is a no-op stub.
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
	return nil
}

// SetStatus is a no-op stub.
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
	return nil
}

// ReloadAllServicesForAccount is a no-op stub.
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
	return nil
}

// ReloadService is a no-op stub.
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
	return nil
}

// GetGlobalServices returns every service in the store.
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
	return m.store.GetServices(ctx, store.LockingStrengthNone)
}

// GetServiceByID looks up a single service in the store by account and proxy ID.
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
	return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}

// GetAccountServices lists the services of one account from the store.
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
	return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}

// GetServiceIDByTargetID is a no-op stub.
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
	return "", nil
}
|
||||||
|
|
||||||
|
// testValidateSessionUsersManager is a users-manager test double that reads
// users straight from the backing store.
type testValidateSessionUsersManager struct {
	store store.Store
}

// GetUser fetches the user by its user ID from the store without locking.
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
	return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
@@ -82,8 +83,9 @@ type DefaultAccountManager struct {
|
|||||||
|
|
||||||
requestBuffer *AccountRequestBuffer
|
requestBuffer *AccountRequestBuffer
|
||||||
|
|
||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
|
reverseProxyManager reverseproxy.Manager
|
||||||
|
|
||||||
// config contains the management server configuration
|
// config contains the management server configuration
|
||||||
config *nbconfig.Config
|
config *nbconfig.Config
|
||||||
@@ -113,6 +115,10 @@ type DefaultAccountManager struct {
|
|||||||
|
|
||||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||||
|
|
||||||
|
// SetServiceManager stores the reverse proxy service manager on the account
// manager. It is called after construction — presumably to avoid a circular
// dependency between the account manager and the reverse proxy module
// (TODO confirm with the wiring code).
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
	am.reverseProxyManager = serviceManager
}
|
||||||
|
|
||||||
func isUniqueConstraintError(err error) bool {
|
func isUniqueConstraintError(err error) bool {
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
||||||
@@ -321,6 +327,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||||
|
}
|
||||||
updateAccountPeers = true
|
updateAccountPeers = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -139,4 +140,5 @@ type Manager interface {
|
|||||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||||
|
SetServiceManager(serviceManager reverseproxy.Manager)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
@@ -1800,6 +1802,14 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
Address: "172.12.6.1/24",
|
Address: "172.12.6.1/24",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Services: []*reverseproxy.Service{
|
||||||
|
{
|
||||||
|
ID: "service1",
|
||||||
|
Name: "test-service",
|
||||||
|
AccountID: "account1",
|
||||||
|
Targets: []*reverseproxy.Target{},
|
||||||
|
},
|
||||||
|
},
|
||||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||||
}
|
}
|
||||||
account.InitOnce()
|
account.InitOnce()
|
||||||
@@ -3112,6 +3122,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
|
||||||
|
|
||||||
return manager, updateManager, nil
|
return manager, updateManager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -204,6 +204,10 @@ const (
|
|||||||
UserInviteLinkRegenerated Activity = 106
|
UserInviteLinkRegenerated Activity = 106
|
||||||
UserInviteLinkDeleted Activity = 107
|
UserInviteLinkDeleted Activity = 107
|
||||||
|
|
||||||
|
ServiceCreated Activity = 108
|
||||||
|
ServiceUpdated Activity = 109
|
||||||
|
ServiceDeleted Activity = 110
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -337,6 +341,10 @@ var activityMap = map[Activity]Code{
|
|||||||
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
|
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
|
||||||
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
|
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
|
||||||
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
|
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
|
||||||
|
|
||||||
|
ServiceCreated: {"Service created", "service.create"},
|
||||||
|
ServiceUpdated: {"Service updated", "service.update"},
|
||||||
|
ServiceDeleted: {"Service deleted", "service.delete"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
@@ -703,7 +703,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||||
permissionsManager := permissions.NewManager(manager.Store)
|
permissionsManager := permissions.NewManager(manager.Store)
|
||||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager)
|
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
|
||||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@@ -12,9 +13,19 @@ import (
|
|||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||||
|
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
@@ -26,6 +37,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||||
|
|
||||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
@@ -60,7 +73,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||||
|
|
||||||
// Register bypass paths for unauthenticated endpoints
|
// Register bypass paths for unauthenticated endpoints
|
||||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||||
@@ -76,6 +89,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
|
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
|
||||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||||
}
|
}
|
||||||
|
// OAuth callback for proxy authentication
|
||||||
|
if err := bypass.AddBypassPath(types.ProxyCallbackEndpointFull); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||||
@@ -156,6 +173,15 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
idp.AddEndpoints(accountManager, router)
|
idp.AddEndpoints(accountManager, router)
|
||||||
instance.AddEndpoints(instanceManager, router)
|
instance.AddEndpoints(instanceManager, router)
|
||||||
instance.AddVersionEndpoint(instanceManager, router)
|
instance.AddVersionEndpoint(instanceManager, router)
|
||||||
|
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
|
||||||
|
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register OAuth callback handler for proxy authentication
|
||||||
|
if proxyGRPCServer != nil {
|
||||||
|
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
|
||||||
|
oauthHandler.RegisterEndpoints(router)
|
||||||
|
}
|
||||||
|
|
||||||
// Mount embedded IdP handler at /oauth2 path if configured
|
// Mount embedded IdP handler at /oauth2 path if configured
|
||||||
if embeddedIdpEnabled {
|
if embeddedIdpEnabled {
|
||||||
|
|||||||
@@ -154,6 +154,11 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peer.ProxyMeta.Embedded {
|
||||||
|
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(ctx, err, w)
|
util.WriteError(ctx, err, w)
|
||||||
@@ -321,6 +326,9 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
|
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
|
||||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
|
if peer.ProxyMeta.Embedded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
|
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
208
management/server/http/handlers/proxy/auth.go
Normal file
208
management/server/http/handlers/proxy/auth.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
type AuthCallbackHandler struct {
	// proxyService validates state parameters, supplies the OIDC
	// configuration, and mints session tokens.
	proxyService *nbgrpc.ProxyServiceServer
	// rateLimiter throttles callback requests per resolved client IP.
	rateLimiter *middleware.APIRateLimiter
	// trustedProxies lists CIDRs whose X-Forwarded-For headers may be
	// believed when resolving the client IP.
	trustedProxies []netip.Prefix
}
|
||||||
|
|
||||||
|
// NewAuthCallbackHandler creates a new OAuth callback handler.
|
||||||
|
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer, trustedProxies []netip.Prefix) *AuthCallbackHandler {
|
||||||
|
rateLimiterConfig := &middleware.RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 10,
|
||||||
|
Burst: 15,
|
||||||
|
CleanupInterval: 5 * time.Minute,
|
||||||
|
LimiterTTL: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AuthCallbackHandler{
|
||||||
|
proxyService: proxyService,
|
||||||
|
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
|
||||||
|
trustedProxies: trustedProxies,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterEndpoints registers the OAuth callback endpoint.
func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
	// GET only: OAuth authorization-code redirects arrive as GET requests.
	router.HandleFunc(types.ProxyCallbackEndpoint, h.handleCallback).Methods(http.MethodGet)
}
|
||||||
|
|
||||||
|
// handleCallback processes the OAuth authorization-code redirect: it rate
// limits by client IP, validates the state parameter, exchanges the code for
// tokens (PKCE), extracts the user ID from the verified ID token, mints a
// session token, and redirects the user back to the original URL with that
// token attached.
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
	clientIP := h.resolveClientIP(r)
	if !h.rateLimiter.Allow(clientIP) {
		log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
		http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
		return
	}

	state := r.URL.Query().Get("state")

	// The state maps back to the PKCE code verifier and the URL the user
	// originally requested; an unknown/expired state is rejected outright.
	codeVerifier, originalURL, err := h.proxyService.ValidateState(state)
	if err != nil {
		log.WithError(err).Error("OAuth callback state validation failed")
		http.Error(w, "Invalid state parameter", http.StatusBadRequest)
		return
	}

	redirectURL, err := url.Parse(originalURL)
	if err != nil {
		log.WithError(err).Error("Failed to parse redirect URL")
		http.Error(w, "Invalid redirect URL", http.StatusBadRequest)
		return
	}

	oidcConfig := h.proxyService.GetOIDCConfig()

	provider, err := oidc.NewProvider(r.Context(), oidcConfig.Issuer)
	if err != nil {
		log.WithError(err).Error("Failed to create OIDC provider")
		http.Error(w, "Failed to create OIDC provider", http.StatusInternalServerError)
		return
	}

	// Exchange the authorization code for tokens, proving possession of
	// the PKCE verifier associated with the state.
	token, err := (&oauth2.Config{
		ClientID:    oidcConfig.ClientID,
		Endpoint:    provider.Endpoint(),
		RedirectURL: oidcConfig.CallbackURL,
	}).Exchange(r.Context(), r.URL.Query().Get("code"), oauth2.VerifierOption(codeVerifier))
	if err != nil {
		log.WithError(err).Error("Failed to exchange code for token")
		http.Error(w, "Failed to exchange code for token", http.StatusInternalServerError)
		return
	}

	userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token)
	if userID == "" {
		log.Error("Failed to extract user ID from OIDC token")
		http.Error(w, "Failed to validate token", http.StatusUnauthorized)
		return
	}

	// Group validation is performed by the proxy via ValidateSession gRPC call.
	// This allows the proxy to show 403 pages directly without redirect dance.

	sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
	if err != nil {
		// Redirect back with an error rather than a bare 5xx so the
		// proxy can render the failure to the user.
		log.WithError(err).Error("Failed to create session token")
		redirectURL.Scheme = "https"
		query := redirectURL.Query()
		query.Set("error", "access_denied")
		query.Set("error_description", "Service configuration error")
		redirectURL.RawQuery = query.Encode()
		http.Redirect(w, r, redirectURL.String(), http.StatusFound)
		return
	}

	// Force HTTPS on the redirect so the session token never travels in clear.
	redirectURL.Scheme = "https"

	query := redirectURL.Query()
	query.Set("session_token", sessionToken)
	redirectURL.RawQuery = query.Encode()

	log.WithField("redirect", redirectURL.Host).Debug("OAuth callback: redirecting user with session token")
	http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
|
||||||
|
|
||||||
|
func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string {
|
||||||
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
log.Warn("No id_token in OIDC response")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := provider.Verifier(&oidc.Config{
|
||||||
|
ClientID: config.ClientID,
|
||||||
|
})
|
||||||
|
|
||||||
|
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warn("Failed to verify ID token")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims struct {
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
}
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
log.WithError(err).Warn("Failed to extract claims from ID token")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims.Subject
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveClientIP extracts the real client IP from the request.
|
||||||
|
// When trustedProxies is non-empty and the direct peer is trusted,
|
||||||
|
// it walks X-Forwarded-For right-to-left skipping trusted IPs.
|
||||||
|
// Otherwise it returns RemoteAddr directly.
|
||||||
|
func (h *AuthCallbackHandler) resolveClientIP(r *http.Request) string {
|
||||||
|
remoteIP := extractHost(r.RemoteAddr)
|
||||||
|
|
||||||
|
if len(h.trustedProxies) == 0 || !isTrustedProxy(remoteIP, h.trustedProxies) {
|
||||||
|
return remoteIP
|
||||||
|
}
|
||||||
|
|
||||||
|
xff := r.Header.Get("X-Forwarded-For")
|
||||||
|
if xff == "" {
|
||||||
|
return remoteIP
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(xff, ",")
|
||||||
|
for i := len(parts) - 1; i >= 0; i-- {
|
||||||
|
ip := strings.TrimSpace(parts[i])
|
||||||
|
if ip == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !isTrustedProxy(ip, h.trustedProxies) {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All IPs in XFF are trusted; return the leftmost as best guess.
|
||||||
|
if first := strings.TrimSpace(parts[0]); first != "" {
|
||||||
|
return first
|
||||||
|
}
|
||||||
|
return remoteIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractHost returns the host portion of a host:port address, or the input
// unchanged when it does not parse as host:port.
func extractHost(remoteAddr string) string {
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		return host
	}
	return remoteAddr
}
|
||||||
|
|
||||||
|
func isTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
|
||||||
|
addr, err := netip.ParseAddr(ipStr)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, prefix := range trusted {
|
||||||
|
if prefix.Contains(addr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -0,0 +1,523 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeOIDCServer creates a minimal OIDC provider for testing.
type fakeOIDCServer struct {
	server       *httptest.Server   // serves the discovery, token, and JWKS endpoints
	issuer       string             // issuer URL (the test server's base URL)
	signingKey   ed25519.PrivateKey // signs issued ID tokens
	publicKey    ed25519.PublicKey  // published via the JWKS endpoint
	keyID        string             // key ID advertised in JWKS and token headers
	tokenSubject string             // "sub" claim placed in issued ID tokens
	tokenExpiry  time.Duration      // lifetime of issued ID tokens
	failExchange bool               // when true, the token endpoint returns invalid_grant
}
|
||||||
|
|
||||||
|
func newFakeOIDCServer() *fakeOIDCServer {
|
||||||
|
pub, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
f := &fakeOIDCServer{
|
||||||
|
signingKey: priv,
|
||||||
|
publicKey: pub,
|
||||||
|
keyID: "test-key-1",
|
||||||
|
tokenExpiry: time.Hour,
|
||||||
|
}
|
||||||
|
f.server = httptest.NewServer(f)
|
||||||
|
f.issuer = f.server.URL
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP routes fake-provider requests to the discovery, token, and JWKS
// handlers; anything else gets a 404.
func (f *fakeOIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.URL.Path {
	case "/.well-known/openid-configuration":
		f.handleDiscovery(w, r)
	case "/token":
		f.handleToken(w, r)
	case "/keys":
		f.handleJWKS(w, r)
	default:
		http.NotFound(w, r)
	}
}
|
||||||
|
|
||||||
|
// handleDiscovery serves the OIDC discovery document, pointing clients at
// this fake server's auth, token, and JWKS endpoints and advertising EdDSA
// as the ID-token signing algorithm.
func (f *fakeOIDCServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) {
	discovery := map[string]interface{}{
		"issuer":                 f.issuer,
		"authorization_endpoint": f.issuer + "/auth",
		"token_endpoint":         f.issuer + "/token",
		"jwks_uri":               f.issuer + "/keys",
		"response_types_supported": []string{
			"code",
			"id_token",
			"token id_token",
		},
		"subject_types_supported":               []string{"public"},
		"id_token_signing_alg_values_supported": []string{"EdDSA"},
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(discovery)
}
|
||||||
|
|
||||||
|
// handleToken implements the token endpoint. When failExchange is set it
// returns invalid_grant; otherwise it issues a freshly signed ID token
// alongside static access and refresh tokens.
func (f *fakeOIDCServer) handleToken(w http.ResponseWriter, r *http.Request) {
	if f.failExchange {
		http.Error(w, "invalid_grant", http.StatusBadRequest)
		return
	}

	if err := r.ParseForm(); err != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	idToken := f.createIDToken()

	response := map[string]interface{}{
		"access_token":  "test-access-token",
		"token_type":    "Bearer",
		"expires_in":    3600,
		"id_token":      idToken,
		"refresh_token": "test-refresh-token",
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(response)
}
|
||||||
|
|
||||||
|
func (f *fakeOIDCServer) createIDToken() string {
|
||||||
|
now := time.Now()
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": f.issuer,
|
||||||
|
"sub": f.tokenSubject,
|
||||||
|
"aud": "test-client-id",
|
||||||
|
"exp": now.Add(f.tokenExpiry).Unix(),
|
||||||
|
"iat": now.Unix(),
|
||||||
|
"nbf": now.Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||||
|
token.Header["kid"] = f.keyID
|
||||||
|
signed, _ := token.SignedString(f.signingKey)
|
||||||
|
return signed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDCServer) handleJWKS(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
jwks := map[string]interface{}{
|
||||||
|
"keys": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"kty": "OKP",
|
||||||
|
"crv": "Ed25519",
|
||||||
|
"kid": f.keyID,
|
||||||
|
"x": base64.RawURLEncoding.EncodeToString(f.publicKey),
|
||||||
|
"use": "sig",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(jwks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDCServer) Close() {
|
||||||
|
f.server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// testSetup contains all test dependencies.
type testSetup struct {
	store        store.Store                // backing store pre-populated with accounts, users and proxies
	oidcServer   *fakeOIDCServer            // in-process fake OIDC provider used for token exchange
	proxyService *nbgrpc.ProxyServiceServer // proxy service wired against the fake provider
	handler      *AuthCallbackHandler       // callback handler under test
	router       *mux.Router                // router with the callback endpoints registered
	cleanup      func()                     // tears down the store and stops the fake OIDC server
}
|
||||||
|
|
||||||
|
// testAccessLogManager is a minimal mock for accesslogs.Manager.
// Both methods are no-ops that always report success.
type testAccessLogManager struct{}

// SaveAccessLog discards the entry and reports success.
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
	return nil
}

// GetAllAccessLogs reports an empty result set.
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, _ *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
	return nil, 0, nil
}
|
||||||
|
|
||||||
|
// setupAuthCallbackTest wires together a full auth-callback fixture:
// a SQL-backed store seeded with test accounts, users and reverse proxies,
// a fake OIDC provider, the proxy gRPC service configured against that
// provider, and an HTTP router with the callback handler registered.
// Callers must invoke the returned setup's cleanup func.
func setupAuthCallbackTest(t *testing.T) *testSetup {
	t.Helper()

	ctx := context.Background()

	// Fresh SQL store per test; cleanup tears it down again.
	testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
	require.NoError(t, err)

	createTestAccountsAndUsers(t, ctx, testStore)
	createTestReverseProxies(t, ctx, testStore)

	oidcServer := newFakeOIDCServer()

	// One-time session hand-off tokens expire after a minute.
	tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)

	usersManager := users.NewManager(testStore)

	// Points the service at the fake provider; the ClientID and HMACKey must
	// match what the fake tokens/state are generated with.
	oidcConfig := nbgrpc.ProxyOIDCConfig{
		Issuer:      oidcServer.issuer,
		ClientID:    "test-client-id",
		Scopes:      []string{"openid", "profile", "email"},
		CallbackURL: "https://management.example.com/reverse-proxy/callback",
		HMACKey:     []byte("test-hmac-key-for-state-signing"),
	}

	proxyService := nbgrpc.NewProxyServiceServer(
		&testAccessLogManager{},
		tokenStore,
		oidcConfig,
		nil, // peers manager is not exercised by the callback flow
		usersManager,
	)

	proxyService.SetProxyManager(&testServiceManager{store: testStore})

	handler := NewAuthCallbackHandler(proxyService, nil)

	router := mux.NewRouter()
	handler.RegisterEndpoints(router)

	return &testSetup{
		store:        testStore,
		oidcServer:   oidcServer,
		proxyService: proxyService,
		handler:      handler,
		router:       router,
		cleanup: func() {
			cleanup()
			oidcServer.Close()
		},
	}
}
|
||||||
|
|
||||||
|
func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store.Store) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||||
|
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||||
|
|
||||||
|
testProxy := &reverseproxy.Service{
|
||||||
|
ID: "testProxyId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Name: "Test Proxy",
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
Targets: []*reverseproxy.Target{{
|
||||||
|
Path: strPtr("/"),
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
TargetId: "peer1",
|
||||||
|
TargetType: "peer",
|
||||||
|
Enabled: true,
|
||||||
|
}},
|
||||||
|
Enabled: true,
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"allowedGroupId"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionPrivateKey: privKey,
|
||||||
|
SessionPublicKey: pubKey,
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||||
|
|
||||||
|
restrictedProxy := &reverseproxy.Service{
|
||||||
|
ID: "restrictedProxyId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Name: "Restricted Proxy",
|
||||||
|
Domain: "restricted-proxy.example.com",
|
||||||
|
Targets: []*reverseproxy.Target{{
|
||||||
|
Path: strPtr("/"),
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
TargetId: "peer1",
|
||||||
|
TargetType: "peer",
|
||||||
|
Enabled: true,
|
||||||
|
}},
|
||||||
|
Enabled: true,
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"restrictedGroupId"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionPrivateKey: privKey,
|
||||||
|
SessionPublicKey: pubKey,
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
||||||
|
|
||||||
|
noAuthProxy := &reverseproxy.Service{
|
||||||
|
ID: "noAuthProxyId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Name: "No Auth Proxy",
|
||||||
|
Domain: "no-auth-proxy.example.com",
|
||||||
|
Targets: []*reverseproxy.Target{{
|
||||||
|
Path: strPtr("/"),
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
TargetId: "peer1",
|
||||||
|
TargetType: "peer",
|
||||||
|
Enabled: true,
|
||||||
|
}},
|
||||||
|
Enabled: true,
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionPrivateKey: privKey,
|
||||||
|
SessionPublicKey: pubKey,
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.CreateService(ctx, noAuthProxy))
|
||||||
|
}
|
||||||
|
|
||||||
|
// strPtr returns a pointer to a copy of s, useful for populating optional
// string fields in test fixtures.
func strPtr(s string) *string {
	v := s
	return &v
}
|
||||||
|
|
||||||
|
// createTestAccountsAndUsers seeds the store with one account
// ("testAccountId"), one group ("allowedGroupId") and one user
// ("allowedUserId") that belongs to the group via AutoGroups.
func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore store.Store) {
	t.Helper()

	testAccount := &types.Account{
		Id:                     "testAccountId",
		Domain:                 "test.com",
		DomainCategory:         "private",
		IsDomainPrimaryAccount: true,
		CreatedAt:              time.Now(),
	}
	require.NoError(t, testStore.SaveAccount(ctx, testAccount))

	allowedGroup := &types.Group{
		ID:        "allowedGroupId",
		AccountID: "testAccountId",
		Name:      "Allowed Group",
		Issued:    "api",
	}
	require.NoError(t, testStore.CreateGroup(ctx, allowedGroup))

	// Member of the group the "testProxyId" proxy distributes to, so this
	// user is expected to pass the bearer-auth group check.
	allowedUser := &types.User{
		Id:         "allowedUserId",
		AccountID:  "testAccountId",
		Role:       types.UserRoleUser,
		AutoGroups: []string{"allowedGroupId"},
		CreatedAt:  time.Now(),
		Issued:     "api",
	}
	require.NoError(t, testStore.SaveUser(ctx, allowedUser))
}
|
||||||
|
|
||||||
|
// testServiceManager is a minimal implementation for testing.
// The read operations the callback flow relies on are backed by the store;
// everything else is a no-op stub.
type testServiceManager struct {
	store store.Store
}

// GetAllServices is unused by the callback flow; stubbed out.
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
	return nil, nil
}

// GetService is unused by the callback flow; stubbed out.
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
	return nil, nil
}

// CreateService is unused by the callback flow; stubbed out.
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
	return nil, nil
}

// UpdateService is unused by the callback flow; stubbed out.
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
	return nil, nil
}

// DeleteService is unused by the callback flow; stubbed out.
func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
	return nil
}

// SetCertificateIssuedAt is unused by the callback flow; stubbed out.
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
	return nil
}

// SetStatus is unused by the callback flow; stubbed out.
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
	return nil
}

// ReloadAllServicesForAccount is unused by the callback flow; stubbed out.
func (m *testServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
	return nil
}

// ReloadService is unused by the callback flow; stubbed out.
func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error {
	return nil
}

// GetGlobalServices returns every service in the store.
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
	return m.store.GetServices(ctx, store.LockingStrengthNone)
}

// GetServiceByID looks a service up directly in the store.
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
	return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}

// GetAccountServices returns the account's services from the store.
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
	return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}

// GetServiceIDByTargetID is unused by the callback flow; stubbed out.
func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
	return "", nil
}
|
||||||
|
|
||||||
|
// createTestState asks the proxy service for an OIDC authorization URL for
// the given redirect target and extracts its signed "state" query parameter,
// which the callback handler later validates.
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
	t.Helper()

	resp, err := ps.GetOIDCURL(context.Background(), &proto.GetOIDCURLRequest{
		RedirectUrl: redirectURL,
		AccountId:   "testAccountId",
	})
	require.NoError(t, err)

	parsedURL, err := url.Parse(resp.Url)
	require.NoError(t, err)

	return parsedURL.Query().Get("state")
}
|
||||||
|
|
||||||
|
// TestAuthCallback_UserAllowedToLogin verifies the happy path: a user in the
// proxy's distribution group completes the OIDC callback and is redirected
// back to the proxy domain with a session token and no error parameter.
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
	setup := setupAuthCallbackTest(t)
	defer setup.cleanup()

	// The fake provider will mint ID tokens for the allowed user.
	setup.oidcServer.tokenSubject = "allowedUserId"

	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/dashboard")

	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
	rec := httptest.NewRecorder()

	setup.router.ServeHTTP(rec, req)

	// Successful auth redirects (302) back to the proxy host.
	require.Equal(t, http.StatusFound, rec.Code)

	location := rec.Header().Get("Location")
	require.NotEmpty(t, location)

	parsedLocation, err := url.Parse(location)
	require.NoError(t, err)

	require.Equal(t, "test-proxy.example.com", parsedLocation.Host)
	require.NotEmpty(t, parsedLocation.Query().Get("session_token"), "Should include session token")
	require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter")
}
|
||||||
|
|
||||||
|
// TestAuthCallback_ProxyNotFound verifies that when the proxy service is
// deleted between state creation and the callback, the handler redirects
// back with an access_denied error instead of issuing a session.
func TestAuthCallback_ProxyNotFound(t *testing.T) {
	setup := setupAuthCallbackTest(t)
	defer setup.cleanup()

	setup.oidcServer.tokenSubject = "allowedUserId"

	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")

	// Remove the proxy after the state was issued but before the callback.
	require.NoError(t, setup.store.DeleteService(context.Background(), "testAccountId", "testProxyId"))

	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
	rec := httptest.NewRecorder()

	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusFound, rec.Code)

	location := rec.Header().Get("Location")
	parsedLocation, err := url.Parse(location)
	require.NoError(t, err)

	require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
}
|
||||||
|
|
||||||
|
// TestAuthCallback_InvalidToken verifies that a failing code-for-token
// exchange at the provider surfaces as a 500 with an exchange error message.
func TestAuthCallback_InvalidToken(t *testing.T) {
	setup := setupAuthCallbackTest(t)
	defer setup.cleanup()

	// Make the fake provider reject the token exchange.
	setup.oidcServer.failExchange = true

	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")

	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=invalid-code&state="+url.QueryEscape(state), nil)
	rec := httptest.NewRecorder()

	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusInternalServerError, rec.Code)
	require.Contains(t, rec.Body.String(), "Failed to exchange code")
}
|
||||||
|
|
||||||
|
// TestAuthCallback_ExpiredToken verifies that an ID token whose expiry lies
// in the past is rejected with 401 during validation.
func TestAuthCallback_ExpiredToken(t *testing.T) {
	setup := setupAuthCallbackTest(t)
	defer setup.cleanup()

	setup.oidcServer.tokenSubject = "allowedUserId"
	// Negative expiry makes the fake provider mint already-expired tokens.
	setup.oidcServer.tokenExpiry = -time.Hour

	state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")

	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
	rec := httptest.NewRecorder()

	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusUnauthorized, rec.Code)
	require.Contains(t, rec.Body.String(), "Failed to validate token")
}
|
||||||
|
|
||||||
|
// TestAuthCallback_InvalidState verifies that a state parameter that was not
// signed by the service is rejected with 400 before any token exchange.
func TestAuthCallback_InvalidState(t *testing.T) {
	setup := setupAuthCallbackTest(t)
	defer setup.cleanup()

	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil)
	rec := httptest.NewRecorder()

	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusBadRequest, rec.Code)
	require.Contains(t, rec.Body.String(), "Invalid state")
}
|
||||||
|
|
||||||
|
// TestAuthCallback_MissingState verifies that a callback without any state
// parameter is rejected with 400.
func TestAuthCallback_MissingState(t *testing.T) {
	setup := setupAuthCallbackTest(t)
	defer setup.cleanup()

	req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil)
	rec := httptest.NewRecorder()

	setup.router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusBadRequest, rec.Code)
}
|
||||||
185
management/server/http/handlers/proxy/auth_test.go
Normal file
185
management/server/http/handlers/proxy/auth_test.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAuthCallbackHandler_RateLimiting exercises the handler's per-IP rate
// limiter directly: requests under the burst limit (15 observed here) pass,
// the request over the limit is blocked, and limits are tracked per IP.
func TestAuthCallbackHandler_RateLimiting(t *testing.T) {
	handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
	require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")

	req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
	req.RemoteAddr = "192.168.1.100:12345"

	t.Run("allows requests under limit", func(t *testing.T) {
		for i := 0; i < 15; i++ {
			allowed := handler.rateLimiter.Allow("192.168.1.100")
			assert.True(t, allowed, "Request %d should be allowed", i+1)
		}
	})

	t.Run("blocks requests over limit", func(t *testing.T) {
		// Fresh counter for this IP so the subtest is order-independent.
		handler.rateLimiter.Reset("192.168.1.200")

		for i := 0; i < 15; i++ {
			handler.rateLimiter.Allow("192.168.1.200")
		}

		allowed := handler.rateLimiter.Allow("192.168.1.200")
		assert.False(t, allowed, "Request over limit should be blocked")
	})

	t.Run("different IPs have separate limits", func(t *testing.T) {
		ip1 := "192.168.1.201"
		ip2 := "192.168.1.202"

		handler.rateLimiter.Reset(ip1)
		handler.rateLimiter.Reset(ip2)

		// Exhaust ip1's budget only.
		for i := 0; i < 15; i++ {
			handler.rateLimiter.Allow(ip1)
		}

		assert.False(t, handler.rateLimiter.Allow(ip1), "IP1 should be blocked")

		assert.True(t, handler.rateLimiter.Allow(ip2), "IP2 should be allowed")
	})
}
|
||||||
|
|
||||||
|
// TestAuthCallbackHandler_RateLimitInHandleCallback verifies the end-to-end
// behavior: once an IP has exhausted its budget, handleCallback answers
// 429 Too Many Requests instead of processing the callback.
func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) {
	handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
	testIP := "10.0.0.50"

	handler.rateLimiter.Reset(testIP)

	t.Run("returns 429 when rate limited", func(t *testing.T) {
		// Exhaust the IP's budget before issuing the real request.
		for i := 0; i < 15; i++ {
			handler.rateLimiter.Allow(testIP)
		}

		req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
		req.RemoteAddr = testIP + ":12345"

		rr := httptest.NewRecorder()
		handler.handleCallback(rr, req)

		assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Should return 429 status code")
		assert.Contains(t, rr.Body.String(), "Too many requests", "Should contain rate limit message")
	})
}
|
||||||
|
|
||||||
|
// TestResolveClientIP covers client-IP resolution against X-Forwarded-For:
// the header is only honored when RemoteAddr is a configured trusted proxy,
// and the resolver walks the XFF chain right-to-left past trusted hops to
// find the first untrusted (i.e. real client) address.
func TestResolveClientIP(t *testing.T) {
	trusted := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/8"),
		netip.MustParsePrefix("172.16.0.0/12"),
	}

	tests := []struct {
		name          string
		remoteAddr    string // value of http.Request.RemoteAddr (usually host:port)
		xForwardedFor string // X-Forwarded-For header; empty means header absent
		trustedProxy  []netip.Prefix
		expectedIP    string
	}{
		{
			name:          "no trusted proxies returns RemoteAddr",
			remoteAddr:    "203.0.113.50:9999",
			xForwardedFor: "1.2.3.4",
			trustedProxy:  nil,
			expectedIP:    "203.0.113.50",
		},
		{
			name:          "untrusted RemoteAddr ignores XFF",
			remoteAddr:    "203.0.113.50:9999",
			xForwardedFor: "1.2.3.4, 10.0.0.1",
			trustedProxy:  trusted,
			expectedIP:    "203.0.113.50",
		},
		{
			name:          "trusted RemoteAddr with single client in XFF",
			remoteAddr:    "10.0.0.1:5000",
			xForwardedFor: "203.0.113.50",
			trustedProxy:  trusted,
			expectedIP:    "203.0.113.50",
		},
		{
			name:          "trusted RemoteAddr walks past trusted entries in XFF",
			remoteAddr:    "10.0.0.1:5000",
			xForwardedFor: "203.0.113.50, 10.0.0.2, 172.16.0.5",
			trustedProxy:  trusted,
			expectedIP:    "203.0.113.50",
		},
		{
			name:         "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
			remoteAddr:   "10.0.0.1:5000",
			trustedProxy: trusted,
			expectedIP:   "10.0.0.1",
		},
		{
			name:          "all XFF IPs trusted returns leftmost",
			remoteAddr:    "10.0.0.1:5000",
			xForwardedFor: "10.0.0.2, 172.16.0.1, 10.0.0.3",
			trustedProxy:  trusted,
			expectedIP:    "10.0.0.2",
		},
		{
			name:          "XFF with whitespace",
			remoteAddr:    "10.0.0.1:5000",
			xForwardedFor: " 203.0.113.50 , 10.0.0.2 ",
			trustedProxy:  trusted,
			expectedIP:    "203.0.113.50",
		},
		{
			name:          "multi-hop with mixed trust",
			remoteAddr:    "10.0.0.1:5000",
			xForwardedFor: "8.8.8.8, 203.0.113.50, 172.16.0.1",
			trustedProxy:  trusted,
			expectedIP:    "203.0.113.50",
		},
		{
			// RemoteAddr without a port must still parse cleanly.
			name:       "RemoteAddr without port",
			remoteAddr: "192.168.1.100",
			expectedIP: "192.168.1.100",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, tt.trustedProxy)

			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			req.RemoteAddr = tt.remoteAddr
			if tt.xForwardedFor != "" {
				req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
			}

			ip := handler.resolveClientIP(req)
			assert.Equal(t, tt.expectedIP, ip)
		})
	}
}
|
||||||
|
|
||||||
|
// TestAuthCallbackHandler_RateLimiterConfiguration pins the limiter's
// default configuration: a burst of 15 requests is allowed for a single IP
// and the 16th is rejected.
func TestAuthCallbackHandler_RateLimiterConfiguration(t *testing.T) {
	handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)

	require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")

	testIP := "192.168.1.250"
	handler.rateLimiter.Reset(testIP)

	for i := 0; i < 15; i++ {
		allowed := handler.rateLimiter.Allow(testIP)
		assert.True(t, allowed, "Should allow request %d within burst limit", i+1)
	}

	allowed := handler.rateLimiter.Allow(testIP)
	assert.False(t, allowed, "Should block request that exceeds burst limit")
}
|
||||||
@@ -10,6 +10,10 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||||
@@ -86,6 +90,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
t.Fatalf("Failed to create manager: %v", err)
|
t.Fatalf("Failed to create manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
|
||||||
|
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||||
|
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
|
||||||
|
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
|
||||||
|
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
|
||||||
|
proxyServiceServer.SetProxyManager(reverseProxyManager)
|
||||||
|
am.SetServiceManager(reverseProxyManager)
|
||||||
|
|
||||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||||
authManagerMock := &serverauth.MockManager{
|
authManagerMock := &serverauth.MockManager{
|
||||||
@@ -102,7 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil)
|
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
|
|||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: idpTimeout(),
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func NewAuthentikManager(config AuthentikClientConfig, appMetrics telemetry.AppM
|
|||||||
Timeout: idpTimeout(),
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.ClientID == "" {
|
if config.ClientID == "" {
|
||||||
|
|||||||
@@ -57,11 +57,11 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
|||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: idpTimeout(),
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.ClientID == "" {
|
if config.ClientID == "" {
|
||||||
|
|||||||
@@ -91,6 +91,12 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
|||||||
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
|
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
|
||||||
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
|
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
|
||||||
|
|
||||||
|
// Build dashboard redirect URIs including the OAuth callback for proxy authentication
|
||||||
|
dashboardRedirectURIs := c.DashboardRedirectURIs
|
||||||
|
baseURL := strings.TrimSuffix(c.Issuer, "/oauth2")
|
||||||
|
// todo: resolve import cycle
|
||||||
|
dashboardRedirectURIs = append(dashboardRedirectURIs, baseURL+"/api/reverse-proxy/callback")
|
||||||
|
|
||||||
cfg := &dex.YAMLConfig{
|
cfg := &dex.YAMLConfig{
|
||||||
Issuer: c.Issuer,
|
Issuer: c.Issuer,
|
||||||
Storage: dex.Storage{
|
Storage: dex.Storage{
|
||||||
@@ -118,7 +124,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
|||||||
ID: staticClientDashboard,
|
ID: staticClientDashboard,
|
||||||
Name: "NetBird Dashboard",
|
Name: "NetBird Dashboard",
|
||||||
Public: true,
|
Public: true,
|
||||||
RedirectURIs: c.DashboardRedirectURIs,
|
RedirectURIs: dashboardRedirectURIs,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: staticClientCLI,
|
ID: staticClientCLI,
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient
|
|||||||
Timeout: idpTimeout(),
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.CustomerID == "" {
|
if config.CustomerID == "" {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
|
|||||||
Timeout: idpTimeout(),
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.ClientID == "" {
|
if config.ClientID == "" {
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMet
|
|||||||
Timeout: idpTimeout(),
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.ManagementEndpoint == "" {
|
if config.ManagementEndpoint == "" {
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ const (
|
|||||||
// Provides the env variable name for use with idpTimeout function
|
// Provides the env variable name for use with idpTimeout function
|
||||||
idpTimeoutEnv = "NB_IDP_TIMEOUT"
|
idpTimeoutEnv = "NB_IDP_TIMEOUT"
|
||||||
// Sets the defaultTimeout to 10s.
|
// Sets the defaultTimeout to 10s.
|
||||||
defaultTimeout = 10 * time.Second
|
defaultTimeout = 10 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// idpTimeout returns a timeout value for the IDP
|
// idpTimeout returns a timeout value for the IDP
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetri
|
|||||||
Timeout: idpTimeout(),
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
hasPAT := config.PAT != ""
|
hasPAT := config.PAT != ""
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
@@ -147,6 +148,10 @@ type MockAccountManager struct {
|
|||||||
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
|
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
|
||||||
if am.CreatePeerJobFunc != nil {
|
if am.CreatePeerJobFunc != nil {
|
||||||
return am.CreatePeerJobFunc(ctx, accountID, peerID, userID, job)
|
return am.CreatePeerJobFunc(ctx, accountID, peerID, userID, job)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
|
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
|
||||||
@@ -52,7 +52,7 @@ func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
|
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
|
||||||
@@ -75,7 +75,7 @@ func Test_GetNetworkReturnsNetwork(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
networks, err := manager.GetNetwork(ctx, accountID, userID, networkID)
|
networks, err := manager.GetNetwork(ctx, accountID, userID, networkID)
|
||||||
@@ -98,7 +98,7 @@ func Test_GetNetworkReturnsPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
network, err := manager.GetNetwork(ctx, accountID, userID, networkID)
|
network, err := manager.GetNetwork(ctx, accountID, userID, networkID)
|
||||||
@@ -123,7 +123,7 @@ func Test_CreateNetworkSuccessfully(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
|
createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
|
||||||
@@ -148,7 +148,7 @@ func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
|
createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
|
||||||
@@ -171,7 +171,7 @@ func Test_DeleteNetworkSuccessfully(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
|
err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
|
||||||
@@ -193,7 +193,7 @@ func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
|
err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
|
||||||
@@ -218,7 +218,7 @@ func Test_UpdateNetworkSuccessfully(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)
|
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)
|
||||||
@@ -245,7 +245,7 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(s)
|
permissionsManager := permissions.NewManager(s)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
routerManager := routers.NewManagerMock()
|
routerManager := routers.NewManagerMock()
|
||||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
|
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||||
|
|
||||||
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)
|
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
@@ -30,21 +33,23 @@ type Manager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
permissionsManager permissions.Manager
|
permissionsManager permissions.Manager
|
||||||
groupsManager groups.Manager
|
groupsManager groups.Manager
|
||||||
accountManager account.Manager
|
accountManager account.Manager
|
||||||
|
reverseProxyManager reverseproxy.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockManager struct {
|
type mockManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager) Manager {
|
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
|
||||||
return &managerImpl{
|
return &managerImpl{
|
||||||
store: store,
|
store: store,
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
groupsManager: groupsManager,
|
groupsManager: groupsManager,
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
|
reverseProxyManager: reverseproxyManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,6 +262,14 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
|||||||
event()
|
event()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
|
||||||
|
go func() {
|
||||||
|
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||||
|
|
||||||
return resource, nil
|
return resource, nil
|
||||||
@@ -309,6 +322,14 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||||
|
}
|
||||||
|
if serviceID != "" {
|
||||||
|
return status.NewResourceInUseError(resourceID, serviceID)
|
||||||
|
}
|
||||||
|
|
||||||
var events []func()
|
var events []func()
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
|
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
@@ -28,7 +30,9 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -49,7 +53,9 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -69,7 +75,9 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -89,7 +97,9 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -112,7 +122,9 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -134,7 +146,9 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -161,7 +175,10 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -187,7 +204,9 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -214,7 +233,9 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -240,7 +261,9 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -270,7 +293,10 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -302,7 +328,9 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -332,7 +360,9 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -361,7 +391,9 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -383,7 +415,10 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -404,7 +439,9 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
|
|||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
am := mock_server.MockAccountManager{}
|
am := mock_server.MockAccountManager{}
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
manager := NewManager(store, permissionsManager, groupsManager, &am)
|
ctrl := gomock.NewController(t)
|
||||||
|
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||||
|
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||||
|
|
||||||
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|||||||
@@ -221,6 +221,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peer.ProxyMeta.Embedded {
|
||||||
|
return fmt.Errorf("not allowed to update peer")
|
||||||
|
}
|
||||||
|
|
||||||
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -489,6 +493,14 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
var settings *types.Settings
|
var settings *types.Settings
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
|
|
||||||
|
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||||
|
}
|
||||||
|
if serviceID != "" {
|
||||||
|
return status.NewPeerInUseError(peerID, serviceID)
|
||||||
|
}
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -549,6 +561,99 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri
|
|||||||
return account.Network.Copy(), err
|
return account.Network.Copy(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type peerAddAuthConfig struct {
|
||||||
|
AccountID string
|
||||||
|
SetupKeyID string
|
||||||
|
SetupKeyName string
|
||||||
|
GroupsToAdd []string
|
||||||
|
AllowExtraDNSLabels bool
|
||||||
|
Ephemeral bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) processPeerAddAuth(ctx context.Context, accountID, userID, encodedHashedKey string, peer *nbpeer.Peer, temporary, addedByUser, addedBySetupKey bool, opEvent *activity.Event) (*peerAddAuthConfig, error) {
|
||||||
|
config := &peerAddAuthConfig{
|
||||||
|
AccountID: accountID,
|
||||||
|
Ephemeral: peer.Ephemeral,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case addedByUser:
|
||||||
|
if err := am.handleUserAddedPeer(ctx, accountID, userID, temporary, opEvent, config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case addedBySetupKey:
|
||||||
|
if err := am.handleSetupKeyAddedPeer(ctx, encodedHashedKey, peer, opEvent, config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if peer.ProxyMeta.Embedded {
|
||||||
|
log.WithContext(ctx).Debugf("adding peer for proxy embedded, accountID: %s", accountID)
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Warnf("adding peer without setup key or userID, accountID: %s", accountID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
opEvent.AccountID = config.AccountID
|
||||||
|
if temporary {
|
||||||
|
config.Ephemeral = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) handleUserAddedPeer(ctx context.Context, accountID, userID string, temporary bool, opEvent *activity.Event, config *peerAddAuthConfig) error {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.NotFound, "failed adding new peer: user not found")
|
||||||
|
}
|
||||||
|
if user.PendingApproval {
|
||||||
|
return status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
|
||||||
|
}
|
||||||
|
|
||||||
|
if temporary {
|
||||||
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
config.AccountID = user.AccountID
|
||||||
|
config.GroupsToAdd = user.AutoGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
opEvent.InitiatorID = userID
|
||||||
|
opEvent.Activity = activity.PeerAddedByUser
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, encodedHashedKey string, peer *nbpeer.Peer, opEvent *activity.Event, config *peerAddAuthConfig) error {
|
||||||
|
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !sk.IsValid() {
|
||||||
|
return status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||||
|
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||||
|
}
|
||||||
|
|
||||||
|
opEvent.InitiatorID = sk.Id
|
||||||
|
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||||
|
config.GroupsToAdd = sk.AutoGroups
|
||||||
|
config.Ephemeral = sk.Ephemeral
|
||||||
|
config.SetupKeyID = sk.Id
|
||||||
|
config.SetupKeyName = sk.Name
|
||||||
|
config.AllowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||||
|
config.AccountID = sk.AccountID
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddPeer adds a new peer to the Store.
|
// AddPeer adds a new peer to the Store.
|
||||||
// Each Account has a list of pre-authorized SetupKey and if no Account has a given key err with a code status.PermissionDenied
|
// Each Account has a list of pre-authorized SetupKey and if no Account has a given key err with a code status.PermissionDenied
|
||||||
// will be returned, meaning the setup key is invalid or not found.
|
// will be returned, meaning the setup key is invalid or not found.
|
||||||
@@ -557,7 +662,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri
|
|||||||
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
||||||
// The peer property is just a placeholder for the Peer properties to pass further
|
// The peer property is just a placeholder for the Peer properties to pass further
|
||||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||||
if setupKey == "" && userID == "" {
|
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
|
||||||
// no auth method provided => reject access
|
// no auth method provided => reject access
|
||||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||||
}
|
}
|
||||||
@@ -566,6 +671,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
addedByUser := len(userID) > 0
|
addedByUser := len(userID) > 0
|
||||||
|
addedBySetupKey := len(setupKey) > 0
|
||||||
|
|
||||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||||
@@ -583,63 +689,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
|
|
||||||
var newPeer *nbpeer.Peer
|
var newPeer *nbpeer.Peer
|
||||||
|
|
||||||
var setupKeyID string
|
peerAddConfig, err := am.processPeerAddAuth(ctx, accountID, userID, encodedHashedKey, peer, temporary, addedByUser, addedBySetupKey, opEvent)
|
||||||
var setupKeyName string
|
if err != nil {
|
||||||
var ephemeral bool
|
return nil, nil, nil, err
|
||||||
var groupsToAdd []string
|
|
||||||
var allowExtraDNSLabels bool
|
|
||||||
if addedByUser {
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found")
|
|
||||||
}
|
|
||||||
if user.PendingApproval {
|
|
||||||
return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
|
|
||||||
}
|
|
||||||
if temporary {
|
|
||||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !allowed {
|
|
||||||
return nil, nil, nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
accountID = user.AccountID
|
|
||||||
groupsToAdd = user.AutoGroups
|
|
||||||
}
|
|
||||||
opEvent.InitiatorID = userID
|
|
||||||
opEvent.Activity = activity.PeerAddedByUser
|
|
||||||
} else {
|
|
||||||
// Validate the setup key
|
|
||||||
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
|
||||||
}
|
|
||||||
|
|
||||||
// we will check key twice for early return
|
|
||||||
if !sk.IsValid() {
|
|
||||||
return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
|
||||||
}
|
|
||||||
|
|
||||||
opEvent.InitiatorID = sk.Id
|
|
||||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
|
||||||
groupsToAdd = sk.AutoGroups
|
|
||||||
ephemeral = sk.Ephemeral
|
|
||||||
setupKeyID = sk.Id
|
|
||||||
setupKeyName = sk.Name
|
|
||||||
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
|
||||||
accountID = sk.AccountID
|
|
||||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
|
||||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
opEvent.AccountID = accountID
|
|
||||||
|
|
||||||
if temporary {
|
|
||||||
ephemeral = true
|
|
||||||
}
|
}
|
||||||
|
accountID = peerAddConfig.AccountID
|
||||||
|
ephemeral := peerAddConfig.Ephemeral
|
||||||
|
|
||||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||||
if am.idpManager != nil {
|
if am.idpManager != nil {
|
||||||
@@ -669,10 +724,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
CreatedAt: registrationTime,
|
CreatedAt: registrationTime,
|
||||||
LoginExpirationEnabled: addedByUser && !temporary,
|
LoginExpirationEnabled: addedByUser && !temporary,
|
||||||
Ephemeral: ephemeral,
|
Ephemeral: ephemeral,
|
||||||
|
ProxyMeta: peer.ProxyMeta,
|
||||||
Location: peer.Location,
|
Location: peer.Location,
|
||||||
InactivityExpirationEnabled: addedByUser && !temporary,
|
InactivityExpirationEnabled: addedByUser && !temporary,
|
||||||
ExtraDNSLabels: peer.ExtraDNSLabels,
|
ExtraDNSLabels: peer.ExtraDNSLabels,
|
||||||
AllowExtraDNSLabels: allowExtraDNSLabels,
|
AllowExtraDNSLabels: peerAddConfig.AllowExtraDNSLabels,
|
||||||
}
|
}
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -690,7 +746,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
|
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, peerAddConfig.GroupsToAdd, settings.Extra, temporary)
|
||||||
|
|
||||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -726,8 +782,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(groupsToAdd) > 0 {
|
if len(peerAddConfig.GroupsToAdd) > 0 {
|
||||||
for _, g := range groupsToAdd {
|
for _, g := range peerAddConfig.GroupsToAdd {
|
||||||
err = transaction.AddPeerToGroup(ctx, newPeer.AccountID, newPeer.ID, g)
|
err = transaction.AddPeerToGroup(ctx, newPeer.AccountID, newPeer.ID, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -735,17 +791,20 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
if !peer.ProxyMeta.Embedded {
|
||||||
if err != nil {
|
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
||||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if addedByUser {
|
switch {
|
||||||
|
case addedByUser:
|
||||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||||
}
|
}
|
||||||
} else {
|
case addedBySetupKey:
|
||||||
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
|
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get setup key: %w", err)
|
return fmt.Errorf("failed to get setup key: %w", err)
|
||||||
@@ -756,7 +815,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
err = transaction.IncrementSetupKeyUsage(ctx, peerAddConfig.SetupKeyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||||
}
|
}
|
||||||
@@ -797,7 +856,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
opEvent.TargetID = newPeer.ID
|
opEvent.TargetID = newPeer.ID
|
||||||
opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings))
|
opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings))
|
||||||
if !addedByUser {
|
if !addedByUser {
|
||||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ type Peer struct {
|
|||||||
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
|
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
|
||||||
// Meta is a Peer system meta data
|
// Meta is a Peer system meta data
|
||||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||||
|
// ProxyMeta is metadata related to proxy peers
|
||||||
|
ProxyMeta ProxyMeta `gorm:"embedded;embeddedPrefix:proxy_meta_"`
|
||||||
// Name is peer's name (machine name)
|
// Name is peer's name (machine name)
|
||||||
Name string `gorm:"index"`
|
Name string `gorm:"index"`
|
||||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||||
@@ -48,6 +50,7 @@ type Peer struct {
|
|||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
// Indicate ephemeral peer attribute
|
// Indicate ephemeral peer attribute
|
||||||
Ephemeral bool `gorm:"index"`
|
Ephemeral bool `gorm:"index"`
|
||||||
|
|
||||||
// Geo location based on connection IP
|
// Geo location based on connection IP
|
||||||
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||||
|
|
||||||
@@ -57,6 +60,11 @@ type Peer struct {
|
|||||||
AllowExtraDNSLabels bool
|
AllowExtraDNSLabels bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ProxyMeta struct {
|
||||||
|
Embedded bool `gorm:"index"`
|
||||||
|
Cluster string `gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
type PeerStatus struct { //nolint:revive
|
type PeerStatus struct { //nolint:revive
|
||||||
// LastSeen is the last time peer was connected to the management service
|
// LastSeen is the last time peer was connected to the management service
|
||||||
LastSeen time.Time
|
LastSeen time.Time
|
||||||
@@ -224,6 +232,7 @@ func (p *Peer) Copy() *Peer {
|
|||||||
LastLogin: p.LastLogin,
|
LastLogin: p.LastLogin,
|
||||||
CreatedAt: p.CreatedAt,
|
CreatedAt: p.CreatedAt,
|
||||||
Ephemeral: p.Ephemeral,
|
Ephemeral: p.Ephemeral,
|
||||||
|
ProxyMeta: p.ProxyMeta,
|
||||||
Location: p.Location,
|
Location: p.Location,
|
||||||
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
||||||
ExtraDNSLabels: slices.Clone(p.ExtraDNSLabels),
|
ExtraDNSLabels: slices.Clone(p.ExtraDNSLabels),
|
||||||
|
|||||||
@@ -2489,3 +2489,252 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
|
|||||||
_, _, _, err = manager.LoginPeer(context.Background(), login)
|
_, _, _, err = manager.LoginPeer(context.Background(), login)
|
||||||
require.NoError(t, err, "Regular user should be able to login peers")
|
require.NoError(t, err, "Regular user should be able to login peers")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleUserAddedPeer(t *testing.T) {
|
||||||
|
manager, _, err := createManager(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||||
|
err = manager.Store.SaveAccount(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("regular user can add peer", func(t *testing.T) {
|
||||||
|
regularUser := types.NewRegularUser("regular-user-1", "", "")
|
||||||
|
regularUser.AccountID = account.Id
|
||||||
|
regularUser.AutoGroups = []string{"group1", "group2"}
|
||||||
|
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
opEvent := &activity.Event{}
|
||||||
|
config := &peerAddAuthConfig{}
|
||||||
|
|
||||||
|
err = manager.handleUserAddedPeer(context.Background(), account.Id, regularUser.Id, false, opEvent, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, account.Id, config.AccountID)
|
||||||
|
assert.Equal(t, regularUser.AutoGroups, config.GroupsToAdd)
|
||||||
|
assert.Equal(t, regularUser.Id, opEvent.InitiatorID)
|
||||||
|
assert.Equal(t, activity.PeerAddedByUser, opEvent.Activity)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pending approval user cannot add peer", func(t *testing.T) {
|
||||||
|
pendingUser := types.NewRegularUser("pending-user", "", "")
|
||||||
|
pendingUser.AccountID = account.Id
|
||||||
|
pendingUser.PendingApproval = true
|
||||||
|
err = manager.Store.SaveUser(context.Background(), pendingUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
opEvent := &activity.Event{}
|
||||||
|
config := &peerAddAuthConfig{}
|
||||||
|
|
||||||
|
err = manager.handleUserAddedPeer(context.Background(), account.Id, pendingUser.Id, false, opEvent, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "user pending approval cannot add peers")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("user not found", func(t *testing.T) {
|
||||||
|
opEvent := &activity.Event{}
|
||||||
|
config := &peerAddAuthConfig{}
|
||||||
|
|
||||||
|
err = manager.handleUserAddedPeer(context.Background(), account.Id, "non-existent-user", false, opEvent, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "user not found")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("temporary peer requires permissions", func(t *testing.T) {
|
||||||
|
regularUser := types.NewRegularUser("regular-user-2", "", "")
|
||||||
|
regularUser.AccountID = account.Id
|
||||||
|
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
opEvent := &activity.Event{}
|
||||||
|
config := &peerAddAuthConfig{}
|
||||||
|
|
||||||
|
// Should fail because user doesn't have permissions for temporary peers
|
||||||
|
err = manager.handleUserAddedPeer(context.Background(), account.Id, regularUser.Id, true, opEvent, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSetupKeyAddedPeer(t *testing.T) {
|
||||||
|
manager, _, err := createManager(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||||
|
err = manager.Store.SaveAccount(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create admin user for setup key creation
|
||||||
|
adminUser := types.NewAdminUser("admin-user")
|
||||||
|
adminUser.AccountID = account.Id
|
||||||
|
err = manager.Store.SaveUser(context.Background(), adminUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("valid setup key", func(t *testing.T) {
|
||||||
|
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, false, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upperKey := strings.ToUpper(setupKey.Key)
|
||||||
|
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||||
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
|
|
||||||
|
opEvent := &activity.Event{}
|
||||||
|
config := &peerAddAuthConfig{}
|
||||||
|
peer := &nbpeer.Peer{ExtraDNSLabels: []string{}}
|
||||||
|
|
||||||
|
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, setupKey.Id, config.SetupKeyID)
|
||||||
|
assert.Equal(t, setupKey.Name, config.SetupKeyName)
|
||||||
|
assert.Equal(t, setupKey.AutoGroups, config.GroupsToAdd)
|
||||||
|
assert.Equal(t, setupKey.Ephemeral, config.Ephemeral)
|
||||||
|
assert.Equal(t, setupKey.Id, opEvent.InitiatorID)
|
||||||
|
assert.Equal(t, activity.PeerAddedWithSetupKey, opEvent.Activity)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid setup key", func(t *testing.T) {
|
||||||
|
invalidKey := "invalid-key"
|
||||||
|
hashedKey := sha256.Sum256([]byte(invalidKey))
|
||||||
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
|
|
||||||
|
opEvent := &activity.Event{}
|
||||||
|
config := &peerAddAuthConfig{}
|
||||||
|
peer := &nbpeer.Peer{}
|
||||||
|
|
||||||
|
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "setup key is invalid")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("expired setup key", func(t *testing.T) {
|
||||||
|
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "expired-key", types.SetupKeyReusable, time.Millisecond, []string{}, 0, adminUser.Id, false, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait for key to expire
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
upperKey := strings.ToUpper(setupKey.Key)
|
||||||
|
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||||
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
|
|
||||||
|
opEvent := &activity.Event{}
|
||||||
|
config := &peerAddAuthConfig{}
|
||||||
|
peer := &nbpeer.Peer{}
|
||||||
|
|
||||||
|
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "setup key is invalid")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("extra DNS labels not allowed", func(t *testing.T) {
|
||||||
|
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "no-dns-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, false, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upperKey := strings.ToUpper(setupKey.Key)
|
||||||
|
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||||
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
|
|
||||||
|
opEvent := &activity.Event{}
|
||||||
|
config := &peerAddAuthConfig{}
|
||||||
|
peer := &nbpeer.Peer{ExtraDNSLabels: []string{"custom.label"}}
|
||||||
|
|
||||||
|
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "doesn't allow extra DNS labels")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("extra DNS labels allowed", func(t *testing.T) {
|
||||||
|
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "dns-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, false, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upperKey := strings.ToUpper(setupKey.Key)
|
||||||
|
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||||
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
|
|
||||||
|
opEvent := &activity.Event{}
|
||||||
|
config := &peerAddAuthConfig{}
|
||||||
|
peer := &nbpeer.Peer{ExtraDNSLabels: []string{"custom.label"}}
|
||||||
|
|
||||||
|
err = manager.handleSetupKeyAddedPeer(context.Background(), encodedHashedKey, peer, opEvent, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, config.AllowExtraDNSLabels)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessPeerAddAuth(t *testing.T) {
|
||||||
|
manager, _, err := createManager(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||||
|
err = manager.Store.SaveAccount(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
adminUser := types.NewAdminUser("admin")
|
||||||
|
adminUser.AccountID = account.Id
|
||||||
|
err = manager.Store.SaveUser(context.Background(), adminUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("user authentication flow", func(t *testing.T) {
|
||||||
|
regularUser := types.NewRegularUser("user-auth-test", "", "")
|
||||||
|
regularUser.AccountID = account.Id
|
||||||
|
regularUser.AutoGroups = []string{"group1"}
|
||||||
|
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||||
|
peer := &nbpeer.Peer{Ephemeral: false}
|
||||||
|
|
||||||
|
config, err := manager.processPeerAddAuth(context.Background(), account.Id, regularUser.Id, "", peer, false, true, false, opEvent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, account.Id, config.AccountID)
|
||||||
|
assert.False(t, config.Ephemeral)
|
||||||
|
assert.Equal(t, regularUser.AutoGroups, config.GroupsToAdd)
|
||||||
|
assert.Equal(t, account.Id, opEvent.AccountID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("setup key authentication flow", func(t *testing.T) {
|
||||||
|
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "auth-test-key", types.SetupKeyReusable, time.Hour, []string{}, 0, adminUser.Id, true, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
upperKey := strings.ToUpper(setupKey.Key)
|
||||||
|
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||||
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
|
|
||||||
|
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||||
|
peer := &nbpeer.Peer{Ephemeral: false}
|
||||||
|
|
||||||
|
config, err := manager.processPeerAddAuth(context.Background(), account.Id, "", encodedHashedKey, peer, false, false, true, opEvent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, account.Id, config.AccountID)
|
||||||
|
assert.True(t, config.Ephemeral) // setupKey.Ephemeral is true
|
||||||
|
assert.Equal(t, setupKey.AutoGroups, config.GroupsToAdd)
|
||||||
|
assert.Equal(t, account.Id, opEvent.AccountID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("temporary flag overrides ephemeral", func(t *testing.T) {
|
||||||
|
regularUser := types.NewRegularUser("temp-user", "", "")
|
||||||
|
regularUser.AccountID = account.Id
|
||||||
|
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||||
|
peer := &nbpeer.Peer{Ephemeral: false}
|
||||||
|
|
||||||
|
config, err := manager.processPeerAddAuth(context.Background(), account.Id, regularUser.Id, "", peer, true, true, false, opEvent)
|
||||||
|
require.Error(t, err) // Will fail permission check but that's expected
|
||||||
|
_ = config // avoid unused warning
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("proxy embedded peer (no auth)", func(t *testing.T) {
|
||||||
|
opEvent := &activity.Event{Timestamp: time.Now()}
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
Ephemeral: false,
|
||||||
|
ProxyMeta: nbpeer.ProxyMeta{Embedded: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := manager.processPeerAddAuth(context.Background(), account.Id, "", "", peer, false, false, false, opEvent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, account.Id, config.AccountID)
|
||||||
|
assert.False(t, config.Ephemeral)
|
||||||
|
assert.Empty(t, config.GroupsToAdd)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,37 +3,39 @@ package modules
|
|||||||
type Module string
|
type Module string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Networks Module = "networks"
|
Networks Module = "networks"
|
||||||
Peers Module = "peers"
|
Peers Module = "peers"
|
||||||
RemoteJobs Module = "remote_jobs"
|
RemoteJobs Module = "remote_jobs"
|
||||||
Groups Module = "groups"
|
Groups Module = "groups"
|
||||||
Settings Module = "settings"
|
Settings Module = "settings"
|
||||||
Accounts Module = "accounts"
|
Accounts Module = "accounts"
|
||||||
Dns Module = "dns"
|
Dns Module = "dns"
|
||||||
Nameservers Module = "nameservers"
|
Nameservers Module = "nameservers"
|
||||||
Events Module = "events"
|
Events Module = "events"
|
||||||
Policies Module = "policies"
|
Policies Module = "policies"
|
||||||
Routes Module = "routes"
|
Routes Module = "routes"
|
||||||
Users Module = "users"
|
Users Module = "users"
|
||||||
SetupKeys Module = "setup_keys"
|
SetupKeys Module = "setup_keys"
|
||||||
Pats Module = "pats"
|
Pats Module = "pats"
|
||||||
IdentityProviders Module = "identity_providers"
|
IdentityProviders Module = "identity_providers"
|
||||||
|
Services Module = "services"
|
||||||
)
|
)
|
||||||
|
|
||||||
var All = map[Module]struct{}{
|
var All = map[Module]struct{}{
|
||||||
Networks: {},
|
Networks: {},
|
||||||
Peers: {},
|
Peers: {},
|
||||||
RemoteJobs: {},
|
RemoteJobs: {},
|
||||||
Groups: {},
|
Groups: {},
|
||||||
Settings: {},
|
Settings: {},
|
||||||
Accounts: {},
|
Accounts: {},
|
||||||
Dns: {},
|
Dns: {},
|
||||||
Nameservers: {},
|
Nameservers: {},
|
||||||
Events: {},
|
Events: {},
|
||||||
Policies: {},
|
Policies: {},
|
||||||
Routes: {},
|
Routes: {},
|
||||||
Users: {},
|
Users: {},
|
||||||
SetupKeys: {},
|
SetupKeys: {},
|
||||||
Pats: {},
|
Pats: {},
|
||||||
IdentityProviders: {},
|
IdentityProviders: {},
|
||||||
|
Services: {},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
@@ -27,6 +28,9 @@ import (
|
|||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
@@ -122,11 +126,13 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(
|
err = db.AutoMigrate(
|
||||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.ProxyAccessToken{},
|
||||||
|
&types.Group{}, &types.GroupPeer{},
|
||||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{},
|
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
|
||||||
|
&accesslogs.AccessLogEntry{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||||
@@ -1094,6 +1100,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
|||||||
Preload("NetworkRouters").
|
Preload("NetworkRouters").
|
||||||
Preload("NetworkResources").
|
Preload("NetworkResources").
|
||||||
Preload("Onboarding").
|
Preload("Onboarding").
|
||||||
|
Preload("Services.Targets").
|
||||||
Take(&account, idQueryCondition, accountID)
|
Take(&account, idQueryCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||||
@@ -1271,6 +1278,17 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
|||||||
account.PostureChecks = checks
|
account.PostureChecks = checks
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
services, err := s.getServices(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account.Services = services
|
||||||
|
}()
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@@ -1672,7 +1690,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
|||||||
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
|
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
|
||||||
meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
|
meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
|
||||||
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
|
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
|
||||||
location_geo_name_id FROM peers WHERE account_id = $1`
|
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster FROM peers WHERE account_id = $1`
|
||||||
rows, err := s.pool.Query(ctx, query, accountID)
|
rows, err := s.pool.Query(ctx, query, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1685,12 +1703,12 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
|||||||
lastLogin, createdAt sql.NullTime
|
lastLogin, createdAt sql.NullTime
|
||||||
sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
|
sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
|
||||||
peerStatusLastSeen sql.NullTime
|
peerStatusLastSeen sql.NullTime
|
||||||
peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool
|
peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool
|
||||||
ip, extraDNS, netAddr, env, flags, files, connIP []byte
|
ip, extraDNS, netAddr, env, flags, files, connIP []byte
|
||||||
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
|
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
|
||||||
metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString
|
metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString
|
||||||
metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString
|
metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString
|
||||||
locationCountryCode, locationCityName sql.NullString
|
locationCountryCode, locationCityName, proxyCluster sql.NullString
|
||||||
locationGeoNameID sql.NullInt64
|
locationGeoNameID sql.NullInt64
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1700,7 +1718,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
|||||||
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
|
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
|
||||||
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files,
|
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files,
|
||||||
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
|
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
|
||||||
&locationCountryCode, &locationCityName, &locationGeoNameID)
|
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if lastLogin.Valid {
|
if lastLogin.Valid {
|
||||||
@@ -1784,6 +1802,12 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
|||||||
if locationGeoNameID.Valid {
|
if locationGeoNameID.Valid {
|
||||||
p.Location.GeoNameID = uint(locationGeoNameID.Int64)
|
p.Location.GeoNameID = uint(locationGeoNameID.Int64)
|
||||||
}
|
}
|
||||||
|
if proxyEmbedded.Valid {
|
||||||
|
p.ProxyMeta.Embedded = proxyEmbedded.Bool
|
||||||
|
}
|
||||||
|
if proxyCluster.Valid {
|
||||||
|
p.ProxyMeta.Cluster = proxyCluster.String
|
||||||
|
}
|
||||||
if ip != nil {
|
if ip != nil {
|
||||||
_ = json.Unmarshal(ip, &p.IP)
|
_ = json.Unmarshal(ip, &p.IP)
|
||||||
}
|
}
|
||||||
@@ -2039,6 +2063,131 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
|||||||
return checks, nil
|
return checks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||||
|
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||||
|
pass_host_header, rewrite_redirects, session_private_key, session_public_key
|
||||||
|
FROM services WHERE account_id = $1`
|
||||||
|
|
||||||
|
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
|
||||||
|
target_id, target_type, enabled
|
||||||
|
FROM targets WHERE service_id = ANY($1)`
|
||||||
|
|
||||||
|
serviceRows, err := s.pool.Query(ctx, serviceQuery, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
|
||||||
|
var s reverseproxy.Service
|
||||||
|
var auth []byte
|
||||||
|
var createdAt, certIssuedAt sql.NullTime
|
||||||
|
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||||
|
err := row.Scan(
|
||||||
|
&s.ID,
|
||||||
|
&s.AccountID,
|
||||||
|
&s.Name,
|
||||||
|
&s.Domain,
|
||||||
|
&s.Enabled,
|
||||||
|
&auth,
|
||||||
|
&createdAt,
|
||||||
|
&certIssuedAt,
|
||||||
|
&status,
|
||||||
|
&proxyCluster,
|
||||||
|
&s.PassHostHeader,
|
||||||
|
&s.RewriteRedirects,
|
||||||
|
&sessionPrivateKey,
|
||||||
|
&sessionPublicKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth != nil {
|
||||||
|
if err := json.Unmarshal(auth, &s.Auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Meta = reverseproxy.ServiceMeta{}
|
||||||
|
if createdAt.Valid {
|
||||||
|
s.Meta.CreatedAt = createdAt.Time
|
||||||
|
}
|
||||||
|
if certIssuedAt.Valid {
|
||||||
|
s.Meta.CertificateIssuedAt = certIssuedAt.Time
|
||||||
|
}
|
||||||
|
if status.Valid {
|
||||||
|
s.Meta.Status = status.String
|
||||||
|
}
|
||||||
|
if proxyCluster.Valid {
|
||||||
|
s.ProxyCluster = proxyCluster.String
|
||||||
|
}
|
||||||
|
if sessionPrivateKey.Valid {
|
||||||
|
s.SessionPrivateKey = sessionPrivateKey.String
|
||||||
|
}
|
||||||
|
if sessionPublicKey.Valid {
|
||||||
|
s.SessionPublicKey = sessionPublicKey.String
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Targets = []*reverseproxy.Target{}
|
||||||
|
return &s, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(services) == 0 {
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceIDs := make([]string, len(services))
|
||||||
|
serviceMap := make(map[string]*reverseproxy.Service)
|
||||||
|
for i, s := range services {
|
||||||
|
serviceIDs[i] = s.ID
|
||||||
|
serviceMap[s.ID] = s
|
||||||
|
}
|
||||||
|
|
||||||
|
targetRows, err := s.pool.Query(ctx, targetsQuery, serviceIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
|
||||||
|
var t reverseproxy.Target
|
||||||
|
var path sql.NullString
|
||||||
|
err := row.Scan(
|
||||||
|
&t.ID,
|
||||||
|
&t.AccountID,
|
||||||
|
&t.ServiceID,
|
||||||
|
&path,
|
||||||
|
&t.Host,
|
||||||
|
&t.Port,
|
||||||
|
&t.Protocol,
|
||||||
|
&t.TargetId,
|
||||||
|
&t.TargetType,
|
||||||
|
&t.Enabled,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if path.Valid {
|
||||||
|
t.Path = &path.String
|
||||||
|
}
|
||||||
|
return &t, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, target := range targets {
|
||||||
|
if service, ok := serviceMap[target.ServiceID]; ok {
|
||||||
|
service.Targets = append(service.Targets, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
|
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
|
||||||
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
|
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
|
||||||
rows, err := s.pool.Query(ctx, query, accountID)
|
rows, err := s.pool.Query(ctx, query, accountID)
|
||||||
@@ -4230,6 +4379,79 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
|
||||||
|
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
|
||||||
|
tx := s.db.WithContext(ctx)
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var token types.ProxyAccessToken
|
||||||
|
result := tx.Take(&token, "hashed_token = ?", hashedToken)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "proxy access token not found")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "get proxy access token: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllProxyAccessTokens retrieves all proxy access tokens.
|
||||||
|
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
|
||||||
|
tx := s.db.WithContext(ctx)
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens []*types.ProxyAccessToken
|
||||||
|
result := tx.Find(&tokens)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.Internal, "get proxy access tokens: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveProxyAccessToken saves a proxy access token to the database.
|
||||||
|
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
|
||||||
|
if result := s.db.WithContext(ctx).Create(token); result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeProxyAccessToken revokes a proxy access token by its ID.
|
||||||
|
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "proxy access token not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
||||||
|
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
|
||||||
|
Where(idQueryCondition, tokenID).
|
||||||
|
Update("last_used", time.Now().UTC())
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "mark proxy access token as used: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "proxy access token not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
|
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
@@ -4602,3 +4824,353 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
|
|||||||
|
|
||||||
return peerID, nil
|
return peerID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||||
|
serviceCopy := service.Copy()
|
||||||
|
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return fmt.Errorf("encrypt service data: %w", err)
|
||||||
|
}
|
||||||
|
result := s.db.Create(serviceCopy)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to create service to store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to create service to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||||
|
serviceCopy := service.Copy()
|
||||||
|
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return fmt.Errorf("encrypt service data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a transaction to ensure atomic updates of the service and its targets
|
||||||
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// Delete existing targets
|
||||||
|
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the service and create new targets
|
||||||
|
if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(serviceCopy).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to update service to store: %v", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to update service to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
|
||||||
|
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete service from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "service %s not found", serviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||||
|
tx := s.db.Preload("Targets")
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var service *reverseproxy.Service
|
||||||
|
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "service %s not found", serviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Errorf("failed to get service from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get service from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypt service data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||||
|
var service *reverseproxy.Service
|
||||||
|
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Errorf("failed to get service by domain from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get service by domain from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypt service data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||||
|
tx := s.db.Preload("Targets")
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var serviceList []*reverseproxy.Service
|
||||||
|
result := tx.Find(&serviceList)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get services from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range serviceList {
|
||||||
|
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypt service data: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return serviceList, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
tx := s.db.Preload("Targets")
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var serviceList []*reverseproxy.Service
|
||||||
|
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get services from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range serviceList {
|
||||||
|
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypt service data: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return serviceList, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
|
||||||
|
tx := s.db
|
||||||
|
|
||||||
|
customDomain := &domain.Domain{}
|
||||||
|
result := tx.Take(&customDomain, accountAndIDQueryCondition, accountID, domainID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "custom domain %s not found", domainID)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Errorf("failed to get custom domain from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get custom domain from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return customDomain, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) {
|
||||||
|
tx := s.db
|
||||||
|
|
||||||
|
var domains []*domain.Domain
|
||||||
|
result := tx.Find(&domains, accountIDCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get reverse proxy custom domains from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get reverse proxy custom domains from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return domains, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error) {
|
||||||
|
newDomain := &domain.Domain{
|
||||||
|
ID: xid.New().String(), // Generate our own ID because gorm doesn't always configure the database to handle this for us.
|
||||||
|
Domain: domainName,
|
||||||
|
AccountID: accountID,
|
||||||
|
TargetCluster: targetCluster,
|
||||||
|
Type: domain.TypeCustom,
|
||||||
|
Validated: validated,
|
||||||
|
}
|
||||||
|
result := s.db.Create(newDomain)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to create reverse proxy custom domain to store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to create reverse proxy custom domain to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return newDomain, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error) {
|
||||||
|
d.AccountID = accountID
|
||||||
|
result := s.db.Select("*").Save(d)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to update reverse proxy custom domain to store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to update reverse proxy custom domain to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error {
|
||||||
|
result := s.db.Delete(domain.Domain{}, accountAndIDQueryCondition, accountID, domainID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete reverse proxy custom domain from store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete reverse proxy custom domain from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "reverse proxy custom domain %s not found", domainID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccessLog creates a new access log entry in the database
|
||||||
|
func (s *SqlStore) CreateAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
|
||||||
|
result := s.db.Create(logEntry)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
|
"service_id": logEntry.ServiceID,
|
||||||
|
"method": logEntry.Method,
|
||||||
|
"host": logEntry.Host,
|
||||||
|
"path": logEntry.Path,
|
||||||
|
}).Errorf("failed to create access log entry in store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to create access log entry in store")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountAccessLogs retrieves access logs for a given account with pagination and filtering
|
||||||
|
func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||||
|
var logs []*accesslogs.AccessLogEntry
|
||||||
|
var totalCount int64
|
||||||
|
|
||||||
|
baseQuery := s.db.WithContext(ctx).
|
||||||
|
Model(&accesslogs.AccessLogEntry{}).
|
||||||
|
Where(accountIDCondition, accountID)
|
||||||
|
|
||||||
|
baseQuery = s.applyAccessLogFilters(baseQuery, filter)
|
||||||
|
|
||||||
|
if err := baseQuery.Count(&totalCount).Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to count access logs: %v", err)
|
||||||
|
return nil, 0, status.Errorf(status.Internal, "failed to count access logs")
|
||||||
|
}
|
||||||
|
|
||||||
|
query := s.db.WithContext(ctx).
|
||||||
|
Where(accountIDCondition, accountID)
|
||||||
|
|
||||||
|
query = s.applyAccessLogFilters(query, filter)
|
||||||
|
|
||||||
|
query = query.
|
||||||
|
Order("timestamp DESC").
|
||||||
|
Limit(filter.GetLimit()).
|
||||||
|
Offset(filter.GetOffset())
|
||||||
|
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
query = query.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
result := query.Find(&logs)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get access logs from store: %v", result.Error)
|
||||||
|
return nil, 0, status.Errorf(status.Internal, "failed to get access logs from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return logs, totalCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyAccessLogFilters applies filter conditions to the query
|
||||||
|
func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.AccessLogFilter) *gorm.DB {
|
||||||
|
if filter.Search != nil {
|
||||||
|
searchPattern := "%" + *filter.Search + "%"
|
||||||
|
query = query.Where(
|
||||||
|
"id LIKE ? OR location_connection_ip LIKE ? OR host LIKE ? OR path LIKE ? OR CONCAT(host, path) LIKE ? OR user_id IN (SELECT id FROM users WHERE email LIKE ? OR name LIKE ?)",
|
||||||
|
searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.SourceIP != nil {
|
||||||
|
query = query.Where("location_connection_ip = ?", *filter.SourceIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.Host != nil {
|
||||||
|
query = query.Where("host = ?", *filter.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.Path != nil {
|
||||||
|
// Support LIKE pattern for path filtering
|
||||||
|
query = query.Where("path LIKE ?", "%"+*filter.Path+"%")
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.UserID != nil {
|
||||||
|
query = query.Where("user_id = ?", *filter.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.Method != nil {
|
||||||
|
query = query.Where("method = ?", *filter.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.Status != nil {
|
||||||
|
switch *filter.Status {
|
||||||
|
case "success":
|
||||||
|
query = query.Where("status_code >= ? AND status_code < ?", 200, 400)
|
||||||
|
case "failed":
|
||||||
|
query = query.Where("status_code < ? OR status_code >= ?", 200, 400)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.StatusCode != nil {
|
||||||
|
query = query.Where("status_code = ?", *filter.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.StartDate != nil {
|
||||||
|
query = query.Where("timestamp >= ?", *filter.StartDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.EndDate != nil {
|
||||||
|
query = query.Where("timestamp <= ?", *filter.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
|
||||||
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var target *reverseproxy.Target
|
||||||
|
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "service target with ID %s not found", targetID)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Errorf("failed to get service target from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get service target from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return target, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
@@ -263,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
|
|||||||
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
|
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
|
||||||
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
||||||
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||||
&types.AccountOnboarding{},
|
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := len(models) - 1; i >= 0; i-- {
|
for i := len(models) - 1; i >= 0; i-- {
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
|
//go:generate go run github.com/golang/mock/mockgen -package store -destination=store_mock.go -source=./store.go -build_flags=-mod=mod
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -23,6 +25,9 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/dns"
|
"github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
@@ -106,6 +111,12 @@ type Store interface {
|
|||||||
SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error
|
SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error
|
||||||
DeletePAT(ctx context.Context, userID, patID string) error
|
DeletePAT(ctx context.Context, userID, patID string) error
|
||||||
|
|
||||||
|
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||||
|
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
|
||||||
|
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
|
||||||
|
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
|
||||||
|
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||||
|
|
||||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
||||||
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
||||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
||||||
@@ -240,6 +251,25 @@ type Store interface {
|
|||||||
MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error
|
MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error
|
||||||
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
|
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
|
||||||
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
|
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
|
||||||
|
|
||||||
|
CreateService(ctx context.Context, service *reverseproxy.Service) error
|
||||||
|
UpdateService(ctx context.Context, service *reverseproxy.Service) error
|
||||||
|
DeleteService(ctx context.Context, accountID, serviceID string) error
|
||||||
|
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
|
||||||
|
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
|
||||||
|
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
|
||||||
|
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
|
||||||
|
|
||||||
|
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
|
||||||
|
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
|
||||||
|
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
|
||||||
|
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
|
||||||
|
UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
|
||||||
|
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
||||||
|
|
||||||
|
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
|
||||||
|
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
|
||||||
|
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
2745
management/server/store/store_mock.go
Normal file
2745
management/server/store/store_mock.go
Normal file
File diff suppressed because it is too large
Load Diff
17
management/server/testdata/auth_callback.sql
vendored
Normal file
17
management/server/testdata/auth_callback.sql
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
-- Schema definitions (must match GORM auto-migrate order)
|
||||||
|
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||||
|
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||||
|
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||||
|
|
||||||
|
-- Test accounts
|
||||||
|
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||||
|
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',1,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||||
|
|
||||||
|
-- Test groups
|
||||||
|
INSERT INTO "groups" VALUES('allowedGroupId','testAccountId','Allowed Group','api','[]',0,'');
|
||||||
|
INSERT INTO "groups" VALUES('restrictedGroupId','testAccountId','Restricted Group','api','[]',0,'');
|
||||||
|
|
||||||
|
-- Test users
|
||||||
|
INSERT INTO users VALUES('allowedUserId','testAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
|
INSERT INTO users VALUES('nonGroupUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
|
INSERT INTO users VALUES('otherAccountUserId','otherAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
@@ -99,6 +100,7 @@ type Account struct {
|
|||||||
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||||
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
|
||||||
// Settings is a dictionary of Account settings
|
// Settings is a dictionary of Account settings
|
||||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||||
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
||||||
@@ -108,6 +110,8 @@ type Account struct {
|
|||||||
|
|
||||||
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
|
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
|
||||||
nmapInitOnce *sync.Once `gorm:"-"`
|
nmapInitOnce *sync.Once `gorm:"-"`
|
||||||
|
|
||||||
|
ReverseProxyFreeDomainNonce string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) InitOnce() {
|
func (a *Account) InitOnce() {
|
||||||
@@ -902,6 +906,11 @@ func (a *Account) Copy() *Account {
|
|||||||
networkResources = append(networkResources, resource.Copy())
|
networkResources = append(networkResources, resource.Copy())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
services := []*reverseproxy.Service{}
|
||||||
|
for _, service := range a.Services {
|
||||||
|
services = append(services, service.Copy())
|
||||||
|
}
|
||||||
|
|
||||||
return &Account{
|
return &Account{
|
||||||
Id: a.Id,
|
Id: a.Id,
|
||||||
CreatedBy: a.CreatedBy,
|
CreatedBy: a.CreatedBy,
|
||||||
@@ -923,6 +932,7 @@ func (a *Account) Copy() *Account {
|
|||||||
Networks: nets,
|
Networks: nets,
|
||||||
NetworkRouters: networkRouters,
|
NetworkRouters: networkRouters,
|
||||||
NetworkResources: networkResources,
|
NetworkResources: networkResources,
|
||||||
|
Services: services,
|
||||||
Onboarding: a.Onboarding,
|
Onboarding: a.Onboarding,
|
||||||
NetworkMapCache: a.NetworkMapCache,
|
NetworkMapCache: a.NetworkMapCache,
|
||||||
nmapInitOnce: a.nmapInitOnce,
|
nmapInitOnce: a.nmapInitOnce,
|
||||||
@@ -1213,7 +1223,7 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe
|
|||||||
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
||||||
for _, p := range uniquePeerIDs {
|
for _, p := range uniquePeerIDs {
|
||||||
peer, ok := a.Peers[p]
|
peer, ok := a.Peers[p]
|
||||||
if !ok || peer == nil {
|
if !ok || peer == nil || peer.ProxyMeta.Embedded {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1776,6 +1786,110 @@ func (a *Account) GetActiveGroupUsers() map[string][]string {
|
|||||||
return groups
|
return groups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetProxyPeers() map[string][]*nbpeer.Peer {
|
||||||
|
proxyPeers := make(map[string][]*nbpeer.Peer)
|
||||||
|
for _, peer := range a.Peers {
|
||||||
|
if peer.ProxyMeta.Embedded {
|
||||||
|
proxyPeers[peer.ProxyMeta.Cluster] = append(proxyPeers[peer.ProxyMeta.Cluster], peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return proxyPeers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) InjectProxyPolicies(ctx context.Context) {
|
||||||
|
if len(a.Services) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyPeersByCluster := a.GetProxyPeers()
|
||||||
|
if len(proxyPeersByCluster) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range a.Services {
|
||||||
|
if !service.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||||
|
for _, target := range service.Targets {
|
||||||
|
if !target.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
a.injectTargetProxyPolicies(ctx, service, target, proxyPeersByCluster[service.ProxyCluster])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
|
||||||
|
port, ok := a.resolveTargetPort(ctx, target)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
path := ""
|
||||||
|
if target.Path != nil {
|
||||||
|
path = *target.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, proxyPeer := range proxyPeers {
|
||||||
|
policy := a.createProxyPolicy(service, target, proxyPeer, port, path)
|
||||||
|
a.Policies = append(a.Policies, policy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
|
||||||
|
if target.Port != 0 {
|
||||||
|
return target.Port, true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch target.Protocol {
|
||||||
|
case "https":
|
||||||
|
return 443, true
|
||||||
|
case "http":
|
||||||
|
return 80, true
|
||||||
|
default:
|
||||||
|
log.WithContext(ctx).Warnf("unsupported protocol %s for proxy target %s, skipping policy injection", target.Protocol, target.TargetId)
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||||
|
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
|
||||||
|
return &Policy{
|
||||||
|
ID: policyID,
|
||||||
|
Name: fmt.Sprintf("Proxy Access to %s", service.Name),
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: policyID,
|
||||||
|
PolicyID: policyID,
|
||||||
|
Name: fmt.Sprintf("Allow access to %s", service.Name),
|
||||||
|
Enabled: true,
|
||||||
|
SourceResource: Resource{
|
||||||
|
ID: proxyPeer.ID,
|
||||||
|
Type: ResourceTypePeer,
|
||||||
|
},
|
||||||
|
DestinationResource: Resource{
|
||||||
|
ID: target.TargetId,
|
||||||
|
Type: ResourceType(target.TargetType),
|
||||||
|
},
|
||||||
|
Bidirectional: false,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{
|
||||||
|
Start: uint16(port),
|
||||||
|
End: uint16(port),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||||
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/dns"
|
"github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
@@ -70,7 +71,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
@@ -115,7 +116,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
|||||||
b.Run("old builder", func(b *testing.B) {
|
b.Run("old builder", func(b *testing.B) {
|
||||||
for range b.N {
|
for range b.N {
|
||||||
for _, peerID := range peerIDs {
|
for _, peerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -177,7 +178,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
@@ -240,7 +241,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
|||||||
b.Run("old builder after add", func(b *testing.B) {
|
b.Run("old builder after add", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -317,7 +318,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
@@ -402,7 +403,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
|||||||
b.Run("old builder after add", func(b *testing.B) {
|
b.Run("old builder after add", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -458,7 +459,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
@@ -537,7 +538,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
@@ -597,7 +598,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
|||||||
b.Run("old builder after delete", func(b *testing.B) {
|
b.Run("old builder after delete", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user