mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 14:44:34 -04:00
Compare commits
50 Commits
feature/ke
...
ebpf-debug
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d02b8cbc97 | ||
|
|
908a27a047 | ||
|
|
a3839a6ef7 | ||
|
|
8aa4f240c7 | ||
|
|
d9686bae92 | ||
|
|
24e19ae287 | ||
|
|
74fde0ea2c | ||
|
|
890e09b787 | ||
|
|
48098c994d | ||
|
|
64f6343fcc | ||
|
|
24713fbe59 | ||
|
|
7794b744f8 | ||
|
|
0d0c30c16d | ||
|
|
b0364da67c | ||
|
|
6dee89379b | ||
|
|
76db4f801a | ||
|
|
6c2ed4b4f2 | ||
|
|
2541c78dd0 | ||
|
|
97b6e79809 | ||
|
|
6ad3847615 | ||
|
|
a4d830ef83 | ||
|
|
9e540cd5b4 | ||
|
|
3027d8f27e | ||
|
|
e69ec6ab6a | ||
|
|
7ddde41c92 | ||
|
|
7ebe58f20a | ||
|
|
9c2c0e7934 | ||
|
|
c6af1037d9 | ||
|
|
5cb9a126f1 | ||
|
|
f40951cdf5 | ||
|
|
6e264d9de7 | ||
|
|
42db9773f4 | ||
|
|
bb9f6f6d0a | ||
|
|
829ce6573e | ||
|
|
a366d9e208 | ||
|
|
e074c24487 | ||
|
|
54fe05f6d8 | ||
|
|
33a155d9aa | ||
|
|
51878659f8 | ||
|
|
c000c05435 | ||
|
|
b39ffef22c | ||
|
|
d96f882acb | ||
|
|
d409219b51 | ||
|
|
8b619a8224 | ||
|
|
ed075bc9b9 | ||
|
|
8eb098d6fd | ||
|
|
68a8687c80 | ||
|
|
f7d97b02fd | ||
|
|
2691e729cd | ||
|
|
b524a9d49d |
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -7,6 +7,16 @@ on:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.goreleaser.yml'
|
||||
- '.goreleaser_ui.yaml'
|
||||
- '.goreleaser_ui_darwin.yaml'
|
||||
- '.github/workflows/release.yml'
|
||||
- 'release_files/**'
|
||||
- '**/Dockerfile'
|
||||
- '**/Dockerfile.*'
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.8"
|
||||
@@ -116,7 +126,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
name: Test Docker Compose Linux
|
||||
name: Test Infrastructure files
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'infrastructure_files/**'
|
||||
- '.github/workflows/test-infrastructure-files.yml'
|
||||
|
||||
|
||||
concurrency:
|
||||
@@ -12,7 +15,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
test-docker-compose:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
@@ -35,7 +38,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: cp setup.env
|
||||
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
||||
@@ -53,6 +56,7 @@ jobs:
|
||||
CI_NETBIRD_MGMT_IDP: "none"
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||
|
||||
- name: check values
|
||||
working-directory: infrastructure_files
|
||||
@@ -68,6 +72,7 @@ jobs:
|
||||
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
|
||||
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
||||
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
||||
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
|
||||
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
||||
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
||||
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
||||
@@ -90,8 +95,8 @@ jobs:
|
||||
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
grep -A 8 DeviceAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||
grep UseIDToken management.json | grep false
|
||||
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
||||
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||
@@ -99,6 +104,12 @@ jobs:
|
||||
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
||||
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
|
||||
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
|
||||
grep -A 2 PKCEAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
|
||||
grep -A 3 PKCEAuthorizationFlow management.json | grep -A 2 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||
grep -A 4 PKCEAuthorizationFlow management.json | grep -A 3 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||
grep -A 5 PKCEAuthorizationFlow management.json | grep -A 4 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files
|
||||
@@ -113,3 +124,28 @@ jobs:
|
||||
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
|
||||
test $count -eq 4
|
||||
working-directory: infrastructure_files
|
||||
|
||||
test-getting-started-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run script
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
- name: test Caddy file gen
|
||||
run: test -f Caddyfile
|
||||
- name: test docker-compose file gen
|
||||
run: test -f docker-compose.yml
|
||||
- name: test management.json file gen
|
||||
run: test -f management.json
|
||||
- name: test turnserver.conf file gen
|
||||
run: test -f turnserver.conf
|
||||
- name: test zitadel.env file gen
|
||||
run: test -f zitadel.env
|
||||
- name: test dashboard.env file gen
|
||||
run: test -f dashboard.env
|
||||
@@ -377,3 +377,13 @@ uploads:
|
||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||
username: dev@wiretrustee.com
|
||||
method: PUT
|
||||
|
||||
checksum:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
@@ -11,6 +11,8 @@ builds:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
tags:
|
||||
- legacy_appindicator
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-ui-windows
|
||||
@@ -55,9 +57,6 @@ nfpms:
|
||||
- src: client/ui/disconnected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- libayatana-appindicator3-1
|
||||
- libgtk-3-dev
|
||||
- libappindicator3-dev
|
||||
- netbird
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
@@ -75,9 +74,6 @@ nfpms:
|
||||
- src: client/ui/disconnected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- libayatana-appindicator3-1
|
||||
- libgtk-3-dev
|
||||
- libappindicator3-dev
|
||||
- netbird
|
||||
|
||||
uploads:
|
||||
|
||||
10
README.md
10
README.md
@@ -59,7 +59,7 @@ NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactiv
|
||||
- \[x] Network Activity Monitoring.
|
||||
|
||||
**Coming soon:**
|
||||
- \[ ] Mobile clients.
|
||||
- \[ ] Mobile clients.
|
||||
|
||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
||||
|
||||
@@ -70,9 +70,9 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
||||
|
||||
### Start using NetBird
|
||||
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
||||
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart).
|
||||
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting).
|
||||
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms.
|
||||
- See our documentation for [Quickstart Guide](https://docs.netbird.io/how-to/getting-started).
|
||||
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://docs.netbird.io/selfhosted/selfhosted-guide).
|
||||
- Step-by-step [Installation Guide](https://docs.netbird.io/how-to/getting-started#installation) for different platforms.
|
||||
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
||||
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
||||
|
||||
@@ -91,7 +91,7 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
||||
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details.
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Roadmap
|
||||
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -35,6 +36,11 @@ type RouteListener interface {
|
||||
routemanager.RouteListener
|
||||
}
|
||||
|
||||
// DnsReadyListener export internal dns ReadyListener for mobile
|
||||
type DnsReadyListener interface {
|
||||
dns.ReadyListener
|
||||
}
|
||||
|
||||
func init() {
|
||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||
}
|
||||
@@ -49,6 +55,7 @@ type Client struct {
|
||||
ctxCancelLock *sync.Mutex
|
||||
deviceName string
|
||||
routeListener routemanager.RouteListener
|
||||
onHostDnsFn func([]string)
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
@@ -65,7 +72,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(urlOpener URLOpener) error {
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
@@ -90,7 +97,8 @@ func (c *Client) Run(urlOpener URLOpener) error {
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener)
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -126,6 +134,17 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
return &PeerInfoArray{items: peerInfos}
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
dnsServer, err := dns.GetServerDns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsServer.OnUpdatedHostDNSServer(list.items)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetConnectionListener set the network connection listener
|
||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||
c.recorder.SetConnectionListener(listener)
|
||||
|
||||
26
client/android/dns_list.go
Normal file
26
client/android/dns_list.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
|
||||
// DNSList is a wrapper of []string
|
||||
type DNSList struct {
|
||||
items []string
|
||||
}
|
||||
|
||||
// Add new DNS address to the collection
|
||||
func (array *DNSList) Add(s string) {
|
||||
array.items = append(array.items, s)
|
||||
}
|
||||
|
||||
// Get return an element of the collection
|
||||
func (array *DNSList) Get(i int) (string, error) {
|
||||
if i >= len(array.items) || i < 0 {
|
||||
return "", fmt.Errorf("out of range")
|
||||
}
|
||||
return array.items[i], nil
|
||||
}
|
||||
|
||||
// Size return with the size of the collection
|
||||
func (array *DNSList) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
24
client/android/dns_list_test.go
Normal file
24
client/android/dns_list_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package android
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDNSList_Get(t *testing.T) {
|
||||
l := DNSList{
|
||||
items: make([]string, 1),
|
||||
}
|
||||
|
||||
_, err := l.Get(0)
|
||||
if err != nil {
|
||||
t.Errorf("invalid error: %s", err)
|
||||
}
|
||||
|
||||
_, err = l.Get(-1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
}
|
||||
|
||||
_, err = l.Get(1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
}
|
||||
}
|
||||
@@ -6,15 +6,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
// SSOListener is async listener for mobile framework
|
||||
@@ -87,9 +86,15 @@ func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
@@ -183,27 +188,15 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) {
|
||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config)
|
||||
if err != nil {
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||
"If you are using hosting Netbird see documentation at " +
|
||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
|
||||
} else {
|
||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||
@@ -211,7 +204,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo,
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
||||
defer cancel()
|
||||
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -163,31 +164,15 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) {
|
||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
|
||||
if err != nil {
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||
"If you are using hosting Netbird see documentation at " +
|
||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
mgmtURL := managementURL
|
||||
if mgmtURL == "" {
|
||||
mgmtURL = internal.DefaultManagementURL
|
||||
}
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your servver or use Setup Keys to login", mgmtURL)
|
||||
} else {
|
||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
@@ -196,7 +181,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
||||
defer c()
|
||||
|
||||
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
@@ -206,8 +191,10 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||
var codeMsg string
|
||||
if !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
if userCode != "" {
|
||||
if !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
}
|
||||
}
|
||||
|
||||
err := open.Run(verificationURIComplete)
|
||||
|
||||
@@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
log.Print(err)
|
||||
log.Debug(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
@@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. " +
|
||||
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
||||
"Run the status command: \n\n" +
|
||||
" netbird status\n\n" +
|
||||
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
||||
return nil
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"You can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
@@ -51,6 +51,7 @@ type Manager interface {
|
||||
dPort *Port,
|
||||
direction RuleDirection,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (Rule, error)
|
||||
|
||||
@@ -60,5 +61,8 @@ type Manager interface {
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
// TODO: migrate routemanager firewal actions to this interface
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -35,6 +36,8 @@ type Manager struct {
|
||||
inputDefaultRuleSpecs []string
|
||||
outputDefaultRuleSpecs []string
|
||||
wgIface iFaceMapper
|
||||
|
||||
rulesets map[string]ruleset
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -43,6 +46,11 @@ type iFaceMapper interface {
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
type ruleset struct {
|
||||
rule *Rule
|
||||
ips map[string]string
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
@@ -51,6 +59,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
||||
outputDefaultRuleSpecs: []string{
|
||||
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
||||
rulesets: make(map[string]ruleset),
|
||||
}
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
// init clients for booth ipv4 and ipv6
|
||||
@@ -58,13 +71,17 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||
}
|
||||
m.ipv4Client = ipv4Client
|
||||
if isIptablesClientAvailable(ipv4Client) {
|
||||
m.ipv4Client = ipv4Client
|
||||
}
|
||||
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
||||
} else {
|
||||
m.ipv6Client = ipv6Client
|
||||
if isIptablesClientAvailable(ipv6Client) {
|
||||
m.ipv6Client = ipv6Client
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
@@ -73,6 +90,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment is empty rule ID is used as comment
|
||||
@@ -83,6 +105,7 @@ func (m *Manager) AddFiltering(
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
@@ -101,22 +124,45 @@ func (m *Manager) AddFiltering(
|
||||
if sPort != nil && sPort.Values != nil {
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
if comment == "" {
|
||||
comment = ruleID
|
||||
}
|
||||
|
||||
specs := m.filterRuleSpecs(
|
||||
"filter",
|
||||
ip,
|
||||
string(protocol),
|
||||
sPortVal,
|
||||
dPortVal,
|
||||
direction,
|
||||
action,
|
||||
comment,
|
||||
)
|
||||
if ipsetName != "" {
|
||||
rs, rsExists := m.rulesets[ipsetName]
|
||||
if !rsExists {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
if rsExists {
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
rs.ips[ip.String()] = ruleID
|
||||
return &Rule{
|
||||
ruleID: ruleID,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}, nil
|
||||
}
|
||||
// this is new ipset so we need to create firewall rule for it
|
||||
}
|
||||
|
||||
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal,
|
||||
direction, action, comment, ipsetName)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
||||
@@ -144,12 +190,24 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
}
|
||||
|
||||
return &Rule{
|
||||
id: ruleID,
|
||||
specs: specs,
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}, nil
|
||||
rule := &Rule{
|
||||
ruleID: ruleID,
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}
|
||||
if ipsetName != "" {
|
||||
// ipset name is defined and it means that this rule was created
|
||||
// for it, need to assosiate it with ruleset
|
||||
m.rulesets[ipsetName] = ruleset{
|
||||
rule: rule,
|
||||
ips: map[string]string{rule.ip: ruleID},
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
@@ -170,6 +228,31 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
client = m.ipv6Client
|
||||
}
|
||||
|
||||
if rs, ok := m.rulesets[r.ipsetName]; ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := rs.ips[r.ip]; ok {
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(rs.ips, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(rs.ips) != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and assosiated firewall rule too
|
||||
delete(m.rulesets, r.ipsetName)
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
r = rs.rule
|
||||
}
|
||||
|
||||
if r.dst {
|
||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
||||
}
|
||||
@@ -193,6 +276,9 @@ func (m *Manager) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// reset firewall chain, clear it and drop it
|
||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||
@@ -233,6 +319,16 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
for ipsetName := range m.rulesets {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
delete(m.rulesets, ipsetName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -240,6 +336,7 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||
func (m *Manager) filterRuleSpecs(
|
||||
table string, ip net.IP, protocol string, sPort, dPort string,
|
||||
direction fw.RuleDirection, action fw.Action, comment string,
|
||||
ipsetName string,
|
||||
) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
@@ -249,11 +346,19 @@ func (m *Manager) filterRuleSpecs(
|
||||
switch direction {
|
||||
case fw.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case fw.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
@@ -335,3 +440,16 @@ func (m *Manager) actionToStr(action fw.Action) string {
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
if ipsetName == "" {
|
||||
return ""
|
||||
} else if sPort != "" && dPort != "" {
|
||||
return ipsetName + "-sport-dport"
|
||||
} else if sPort != "" {
|
||||
return ipsetName + "-sport"
|
||||
} else if dPort != "" {
|
||||
return ipsetName + "-dport"
|
||||
}
|
||||
return ipsetName
|
||||
}
|
||||
|
||||
@@ -55,12 +55,13 @@ func TestIptablesManager(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
err := manager.Reset()
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
@@ -68,7 +69,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
@@ -81,33 +82,31 @@ func TestIptablesManager(t *testing.T) {
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
if err := manager.DeleteRule(rule1); err != nil {
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
err := manager.DeleteRule(rule1)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
if err := manager.DeleteRule(rule2); err != nil {
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
err := manager.DeleteRule(rule2)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...)
|
||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset()
|
||||
@@ -122,6 +121,88 @@ func TestIptablesManager(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestIptablesManagerIPSet(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 fw.Rule
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(
|
||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
})
|
||||
|
||||
var rule2 fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{443},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule1)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule2)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
})
|
||||
}
|
||||
|
||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
||||
require.NoError(t, err, "failed to check rule")
|
||||
@@ -153,9 +234,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
err := manager.Reset()
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
@@ -167,9 +248,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -2,13 +2,16 @@ package iptables
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
id string
|
||||
ruleID string
|
||||
ipsetName string
|
||||
|
||||
specs []string
|
||||
ip string
|
||||
dst bool
|
||||
v6 bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.id
|
||||
return r.ruleID
|
||||
}
|
||||
|
||||
@@ -6,12 +6,14 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -29,11 +31,14 @@ const (
|
||||
FilterOutputChainName = "netbird-acl-output-filter"
|
||||
)
|
||||
|
||||
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conn *nftables.Conn
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
tableIPv4 *nftables.Table
|
||||
tableIPv6 *nftables.Table
|
||||
|
||||
@@ -43,6 +48,10 @@ type Manager struct {
|
||||
filterInputChainIPv6 *nftables.Chain
|
||||
filterOutputChainIPv6 *nftables.Chain
|
||||
|
||||
rulesetManager *rulesetManager
|
||||
setRemovedIPs map[string]struct{}
|
||||
setRemoved map[string]*nftables.Set
|
||||
|
||||
wgIface iFaceMapper
|
||||
}
|
||||
|
||||
@@ -54,8 +63,23 @@ type iFaceMapper interface {
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for booth type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
conn: &nftables.Conn{},
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
|
||||
rulesetManager: newRuleManager(),
|
||||
setRemovedIPs: map[string]struct{}{},
|
||||
setRemoved: map[string]*nftables.Set{},
|
||||
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
@@ -77,6 +101,7 @@ func (m *Manager) AddFiltering(
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
@@ -84,6 +109,7 @@ func (m *Manager) AddFiltering(
|
||||
|
||||
var (
|
||||
err error
|
||||
ipset *nftables.Set
|
||||
table *nftables.Table
|
||||
chain *nftables.Chain
|
||||
)
|
||||
@@ -107,6 +133,46 @@ func (m *Manager) AddFiltering(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawIP := ip.To4()
|
||||
if rawIP == nil {
|
||||
rawIP = ip.To16()
|
||||
}
|
||||
|
||||
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||
|
||||
if ipsetName != "" {
|
||||
// if we already have set with given name, just add ip to the set
|
||||
// and return rule with new ID in other case let's create rule
|
||||
// with fresh created set and set element
|
||||
|
||||
var isSetNew bool
|
||||
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
}
|
||||
isSetNew = true
|
||||
}
|
||||
|
||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||
}
|
||||
|
||||
if !isSetNew {
|
||||
// if we already have nftables rules with set for given direction
|
||||
// just add new rule to the ruleset and return new fw.Rule object
|
||||
|
||||
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
// if ipset exists but it is not linked to rule for given direction
|
||||
// create new rule for direction and bind ipset to it later
|
||||
}
|
||||
}
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
@@ -146,39 +212,47 @@ func (m *Manager) AddFiltering(
|
||||
})
|
||||
}
|
||||
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if s := ip.String(); s != "0.0.0.0" && s != "::" {
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
var adrLen, adrOffset uint32
|
||||
if ip.To4() == nil {
|
||||
adrLen = 16
|
||||
adrOffset = 8
|
||||
} else {
|
||||
adrLen = 4
|
||||
adrOffset = 12
|
||||
addrLen := uint32(len(rawIP))
|
||||
addrOffset := uint32(12)
|
||||
if addrLen == 16 {
|
||||
addrOffset = 8
|
||||
}
|
||||
|
||||
// change to destination address position if need
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
adrOffset += adrLen
|
||||
addrOffset += addrLen
|
||||
}
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: adrOffset,
|
||||
Len: adrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
Offset: addrOffset,
|
||||
Len: addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
if ipset == nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
expressions = append(expressions,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipsetName,
|
||||
SetID: ipset.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
@@ -219,39 +293,76 @@ func (m *Manager) AddFiltering(
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
userData := []byte(strings.Join([]string{id, comment}, " "))
|
||||
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||
|
||||
_ = m.conn.InsertRule(&nftables.Rule{
|
||||
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||
}
|
||||
|
||||
list, err := m.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
|
||||
// getRulesetID returns ruleset ID based on given parameters
|
||||
func (m *Manager) getRulesetID(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
) string {
|
||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
if dPort != nil {
|
||||
rulesetID += dPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
rulesetID += strconv.Itoa(int(action))
|
||||
if ipsetName == "" {
|
||||
return "ip:" + ip.String() + rulesetID
|
||||
}
|
||||
return "set:" + ipsetName + rulesetID
|
||||
}
|
||||
|
||||
// createSet in given table by name
|
||||
func (m *Manager) createSet(
|
||||
table *nftables.Table,
|
||||
rawIP []byte,
|
||||
name string,
|
||||
) (*nftables.Set, error) {
|
||||
keyType := nftables.TypeIPAddr
|
||||
if len(rawIP) == 16 {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
// else we create new ipset and continue creating rule
|
||||
ipset := &nftables.Set{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: keyType,
|
||||
}
|
||||
|
||||
// Add the rule to the chain
|
||||
rule := &Rule{id: id}
|
||||
for _, r := range list {
|
||||
if bytes.Equal(r.UserData, userData) {
|
||||
rule.Rule = r
|
||||
break
|
||||
}
|
||||
}
|
||||
if rule.Rule == nil {
|
||||
return nil, fmt.Errorf("rule not found")
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
return nil, fmt.Errorf("create set: %v", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush created set: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
// chain returns the chain for the given IP address with specific settings
|
||||
@@ -315,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
tables, err := m.conn.ListTablesOfFamily(family)
|
||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
@@ -326,7 +437,11 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
|
||||
}
|
||||
}
|
||||
|
||||
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return table, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createChainIfNotExists(
|
||||
@@ -341,7 +456,7 @@ func (m *Manager) createChainIfNotExists(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chains, err := m.conn.ListChainsOfTableFamily(family)
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
@@ -362,7 +477,7 @@ func (m *Manager) createChainIfNotExists(
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
chain = m.conn.AddChain(chain)
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
shiftDSTAddr := 0
|
||||
@@ -429,7 +544,7 @@ func (m *Manager) createChainIfNotExists(
|
||||
)
|
||||
}
|
||||
|
||||
_ = m.conn.AddRule(&nftables.Rule{
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
@@ -444,12 +559,13 @@ func (m *Manager) createChainIfNotExists(
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = m.conn.AddRule(&nftables.Rule{
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
if err := m.conn.Flush(); err != nil {
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -458,16 +574,58 @@ func (m *Manager) createChainIfNotExists(
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
nativeRule, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
|
||||
return err
|
||||
if nativeRule.nftRule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.conn.Flush()
|
||||
if nativeRule.nftSet != nil {
|
||||
// call twice of delete set element raises error
|
||||
// so we need to check if element is already removed
|
||||
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
||||
if _, ok := m.setRemovedIPs[key]; !ok {
|
||||
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
m.setRemovedIPs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if m.rulesetManager.deleteRule(nativeRule) {
|
||||
// deleteRule indicates that we still have IP in the ruleset
|
||||
// it means we should not remove the nftables rule but need to update set
|
||||
// so we prepare IP to be removed from set on the next flush call
|
||||
return nil
|
||||
}
|
||||
|
||||
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
||||
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
nativeRule.nftRule = nil
|
||||
|
||||
if nativeRule.nftSet != nil {
|
||||
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
||||
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
||||
}
|
||||
nativeRule.nftSet = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -475,27 +633,116 @@ func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
chains, err := m.conn.ListChains()
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
for _, c := range chains {
|
||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||
m.conn.DelChain(c)
|
||||
m.rConn.DelChain(c)
|
||||
}
|
||||
}
|
||||
|
||||
tables, err := m.conn.ListTables()
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
m.conn.DelTable(t)
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
return m.conn.Flush()
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set must be removed after flush rule changes
|
||||
// otherwise we will get error
|
||||
for _, s := range m.setRemoved {
|
||||
m.rConn.FlushSet(s)
|
||||
m.rConn.DelSet(s)
|
||||
}
|
||||
|
||||
if len(m.setRemoved) > 0 {
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m.setRemovedIPs = map[string]struct{}{}
|
||||
m.setRemoved = map[string]*nftables.Set{}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime = backoffTime * 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
||||
if table == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := m.rConn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) != 0 {
|
||||
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
||||
log.Errorf("failed to set rule handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodePort(port fw.Port) []byte {
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset()
|
||||
@@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
// test expectations:
|
||||
// 1) regular rule
|
||||
// 2) "accept extra routed traffic rule" for the interface
|
||||
@@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) {
|
||||
err = manager.DeleteRule(rule)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// test expectations:
|
||||
@@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
@@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,11 +6,14 @@ import (
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
*nftables.Rule
|
||||
id string
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
|
||||
ruleID string
|
||||
ip []byte
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.id
|
||||
return r.ruleID
|
||||
}
|
||||
|
||||
115
client/firewall/nftables/ruleset_linux.go
Normal file
115
client/firewall/nftables/ruleset_linux.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
||||
type nftRuleset struct {
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
issuedRules map[string]*Rule
|
||||
rulesetID string
|
||||
}
|
||||
|
||||
type rulesetManager struct {
|
||||
rulesets map[string]*nftRuleset
|
||||
|
||||
nftSetName2rulesetID map[string]string
|
||||
issuedRuleID2rulesetID map[string]string
|
||||
}
|
||||
|
||||
func newRuleManager() *rulesetManager {
|
||||
return &rulesetManager{
|
||||
rulesets: map[string]*nftRuleset{},
|
||||
|
||||
nftSetName2rulesetID: map[string]string{},
|
||||
issuedRuleID2rulesetID: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
||||
ruleset, ok := r.rulesets[rulesetID]
|
||||
return ruleset, ok
|
||||
}
|
||||
|
||||
func (r *rulesetManager) createRuleset(
|
||||
rulesetID string,
|
||||
nftRule *nftables.Rule,
|
||||
nftSet *nftables.Set,
|
||||
) *nftRuleset {
|
||||
ruleset := nftRuleset{
|
||||
rulesetID: rulesetID,
|
||||
nftRule: nftRule,
|
||||
nftSet: nftSet,
|
||||
issuedRules: map[string]*Rule{},
|
||||
}
|
||||
r.rulesets[ruleset.rulesetID] = &ruleset
|
||||
if nftSet != nil {
|
||||
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
||||
}
|
||||
return &ruleset
|
||||
}
|
||||
|
||||
func (r *rulesetManager) addRule(
|
||||
ruleset *nftRuleset,
|
||||
ip []byte,
|
||||
) (*Rule, error) {
|
||||
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
||||
return nil, fmt.Errorf("ruleset not found")
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
nftRule: ruleset.nftRule,
|
||||
nftSet: ruleset.nftSet,
|
||||
ruleID: xid.New().String(),
|
||||
ip: ip,
|
||||
}
|
||||
|
||||
ruleset.issuedRules[rule.ruleID] = &rule
|
||||
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
// deleteRule from ruleset and returns true if contains other rules
|
||||
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
||||
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
ruleset := r.rulesets[rulesetID]
|
||||
if ruleset.nftRule == nil {
|
||||
return false
|
||||
}
|
||||
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
||||
delete(ruleset.issuedRules, rule.ruleID)
|
||||
|
||||
if len(ruleset.issuedRules) == 0 {
|
||||
delete(r.rulesets, ruleset.rulesetID)
|
||||
if rule.nftSet != nil {
|
||||
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
||||
//
|
||||
// This is important to do, because after we add rule to the nftables we can't update it until
|
||||
// we set correct handle value to it.
|
||||
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
||||
split := bytes.Split(nftRule.UserData, []byte(" "))
|
||||
ruleset, ok := r.rulesets[string(split[0])]
|
||||
if !ok {
|
||||
return fmt.Errorf("ruleset not found")
|
||||
}
|
||||
*ruleset.nftRule = *nftRule
|
||||
return nil
|
||||
}
|
||||
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRulesetManager_createRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{
|
||||
UserData: []byte(rulesetID),
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
||||
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_addRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
||||
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
||||
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
||||
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
||||
|
||||
ruleset2 := &nftRuleset{
|
||||
rulesetID: "ruleset-2",
|
||||
}
|
||||
_, err = rulesetManager.addRule(ruleset2, ip)
|
||||
require.Error(t, err, "addRule() should have failed")
|
||||
}
|
||||
|
||||
func TestRulesetManager_deleteRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
ip2 := []byte("192.168.1.1")
|
||||
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule2, "rule should not be nil")
|
||||
|
||||
hasNext := rulesetManager.deleteRule(rule)
|
||||
require.True(t, hasNext, "deleteRule() should have returned true")
|
||||
|
||||
// Check that the rule is no longer in the manager.
|
||||
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
||||
|
||||
hasNext = rulesetManager.deleteRule(rule2)
|
||||
require.False(t, hasNext, "deleteRule() should have returned false")
|
||||
}
|
||||
|
||||
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.0.1")
|
||||
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
nftRuleCopy := nftRule
|
||||
nftRuleCopy.Handle = 2
|
||||
nftRuleCopy.UserData = []byte(rulesetID)
|
||||
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
||||
require.NoError(t, err, "setNftRuleHandle() failed")
|
||||
// check correct work with references
|
||||
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_getRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
nftSet := nftables.Set{
|
||||
ID: 2,
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
|
||||
find, ok := rulesetManager.getRuleset(rulesetID)
|
||||
require.True(t, ok, "getRuleset() failed")
|
||||
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
||||
|
||||
_, ok = rulesetManager.getRuleset("does-not-exist")
|
||||
require.False(t, ok, "getRuleset() failed")
|
||||
}
|
||||
@@ -21,11 +21,13 @@ type IFaceMapper interface {
|
||||
SetFilter(iface.PacketFilter) error
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]Rule
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules []Rule
|
||||
incomingRules []Rule
|
||||
rulesIndex map[string]int
|
||||
outgoingRules map[string]RuleSet
|
||||
incomingRules map[string]RuleSet
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
|
||||
@@ -48,7 +50,6 @@ type decoder struct {
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rulesIndex: make(map[string]int),
|
||||
decoders: sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
@@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return d
|
||||
},
|
||||
},
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
@@ -81,6 +84,7 @@ func (m *Manager) AddFiltering(
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
r := Rule{
|
||||
@@ -124,15 +128,17 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var p int
|
||||
if direction == fw.RuleDirectionIN {
|
||||
m.incomingRules = append(m.incomingRules, r)
|
||||
p = len(m.incomingRules) - 1
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
m.outgoingRules = append(m.outgoingRules, r)
|
||||
p = len(m.outgoingRules) - 1
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
m.rulesIndex[r.id] = p
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &r, nil
|
||||
@@ -148,24 +154,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
p, ok := m.rulesIndex[r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.rulesIndex, r.id)
|
||||
|
||||
var toUpdate []Rule
|
||||
if r.direction == fw.RuleDirectionIN {
|
||||
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
|
||||
toUpdate = m.incomingRules
|
||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
} else {
|
||||
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
|
||||
toUpdate = m.outgoingRules
|
||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||
}
|
||||
|
||||
for i := 0; i < len(toUpdate); i++ {
|
||||
m.rulesIndex[toUpdate[i].id] = i
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -174,13 +176,15 @@ func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = m.outgoingRules[:0]
|
||||
m.incomingRules = m.incomingRules[:0]
|
||||
m.rulesIndex = make(map[string]int)
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
||||
@@ -192,7 +196,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
}
|
||||
|
||||
// dropFilter imlements same logic for booth direction of the traffic
|
||||
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
@@ -224,37 +228,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
// check if IP address match by IP
|
||||
var ip net.IP
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
ip = d.ip4.SrcIP
|
||||
} else {
|
||||
ip = d.ip4.DstIP
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
ip = d.ip6.SrcIP
|
||||
} else {
|
||||
ip = d.ip6.DstIP
|
||||
}
|
||||
}
|
||||
|
||||
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
||||
if ok {
|
||||
return filter
|
||||
}
|
||||
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
||||
if ok {
|
||||
return filter
|
||||
}
|
||||
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
||||
if ok {
|
||||
return filter
|
||||
}
|
||||
|
||||
// default policy is DROP ALL
|
||||
return true
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP {
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
|
||||
if payloadLayer != rule.protoLayer {
|
||||
@@ -264,38 +280,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.udpHook(packetData)
|
||||
return rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
}
|
||||
|
||||
// default policy is DROP ALL
|
||||
return true
|
||||
return false, false
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
@@ -325,19 +339,19 @@ func (m *Manager) AddUDPPacketHook(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var toUpdate []Rule
|
||||
if in {
|
||||
r.direction = fw.RuleDirectionIN
|
||||
m.incomingRules = append([]Rule{r}, m.incomingRules...)
|
||||
toUpdate = m.incomingRules
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
|
||||
toUpdate = m.outgoingRules
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
|
||||
for i := range toUpdate {
|
||||
m.rulesIndex[toUpdate[i].id] = i
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
@@ -345,14 +359,18 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
for _, r := range m.incomingRules {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, r := range m.outgoingRules {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
|
||||
@@ -63,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -98,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -111,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
|
||||
t.Errorf("rule2 is not in the rulesIndex")
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
|
||||
err = m.DeleteRule(rule2)
|
||||
@@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
|
||||
t.Errorf("rule1 still in the rulesIndex")
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &Manager{
|
||||
incomingRules: []Rule{},
|
||||
outgoingRules: []Rule{},
|
||||
rulesIndex: make(map[string]int),
|
||||
incomingRules: map[string]RuleSet{},
|
||||
outgoingRules: map[string]RuleSet{},
|
||||
}
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule Rule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules) != 1 {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
return
|
||||
}
|
||||
addedRule = manager.incomingRules[0]
|
||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
if len(manager.outgoingRules) != 1 {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||
return
|
||||
}
|
||||
addedRule = manager.outgoingRules[0]
|
||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
|
||||
if !tt.ip.Equal(addedRule.ip) {
|
||||
@@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure rulesIndex is correctly updated
|
||||
index, ok := manager.rulesIndex[addedRule.id]
|
||||
if !ok {
|
||||
t.Errorf("expected rule to be in rulesIndex")
|
||||
return
|
||||
}
|
||||
if index != 0 {
|
||||
t.Errorf("expected rule index to be 0, got %d", index)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -244,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||
t.Errorf("rules is not empty")
|
||||
}
|
||||
}
|
||||
@@ -282,7 +274,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -346,10 +338,12 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
for _, rule := range manager.outgoingRules {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,9 +358,11 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||
for _, rule := range manager.outgoingRules {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -394,9 +390,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -33,9 +33,22 @@ type Manager interface {
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
manager firewall.Manager
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
manager firewall.Manager
|
||||
ipsetCounter int
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type ipsetInfo struct {
|
||||
name string
|
||||
ipCount int
|
||||
}
|
||||
|
||||
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||
@@ -61,6 +74,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.manager.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||
|
||||
enableSSH := (networkMap.PeerConfig != nil &&
|
||||
@@ -108,8 +127,32 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
|
||||
applyFailed := false
|
||||
newRulePairs := make(map[string][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
||||
|
||||
// calculate which IP's can be grouped in by which ipset
|
||||
// to do that we use rule selector (which is just rule properties without IP's)
|
||||
for _, r := range rules {
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r)
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipset, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
ipset = &ipsetInfo{}
|
||||
}
|
||||
|
||||
ipset.ipCount++
|
||||
ipsetByRuleSelectors[selector] = ipset
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||
ipsetName := ""
|
||||
if ipset.name == "" {
|
||||
d.ipsetCounter++
|
||||
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
}
|
||||
ipsetName = ipset.name
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
applyFailed = true
|
||||
@@ -154,7 +197,10 @@ func (d *DefaultManager) Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) {
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (string, []firewall.Rule, error) {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
@@ -190,9 +236,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
||||
var err error
|
||||
switch r.Direction {
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, "")
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.FirewallRule_OUT:
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, "")
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
@@ -205,9 +251,17 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
||||
return ruleID, rules, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
||||
func (d *DefaultManager) addInRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment)
|
||||
rule, err := d.manager.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -217,7 +271,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment)
|
||||
rule, err = d.manager.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -225,9 +280,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return append(rules, rule), nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
||||
func (d *DefaultManager) addOutRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment)
|
||||
rule, err := d.manager.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -237,7 +300,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment)
|
||||
rule, err = d.manager.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -282,6 +346,10 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
in := protoMatch{}
|
||||
out := protoMatch{}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||
|
||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||
// But we zeroed the IP's for protocol if:
|
||||
@@ -298,12 +366,22 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = map[string]int{}
|
||||
}
|
||||
match := protocols[r.Protocol]
|
||||
|
||||
if _, ok := match[r.PeerIP]; ok {
|
||||
// special case, when we recieve this all network IP address
|
||||
// it means that rules for that protocol was already optimized on the
|
||||
// management side
|
||||
if r.PeerIP == "0.0.0.0" {
|
||||
squashedRules = append(squashedRules, r)
|
||||
squashedProtocols[r.Protocol] = struct{}{}
|
||||
return
|
||||
}
|
||||
match[r.PeerIP] = i
|
||||
|
||||
ipset := protocols[r.Protocol]
|
||||
|
||||
if _, ok := ipset[r.PeerIP]; ok {
|
||||
return
|
||||
}
|
||||
ipset[r.PeerIP] = i
|
||||
}
|
||||
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
@@ -324,9 +402,6 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
mgmProto.FirewallRule_UDP,
|
||||
}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
||||
for _, protocol := range protocolOrders {
|
||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||
@@ -382,6 +457,11 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
return append(rules, squashedRules...), squashedProtocols
|
||||
}
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
||||
switch protocol {
|
||||
case mgmProto.FirewallRule_TCP:
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
@@ -18,10 +17,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}, nil
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
@@ -29,8 +29,5 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}, nil
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
|
||||
202
client/internal/auth/device_flow.go
Normal file
202
client/internal/auth/device_flow.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HostedGrantType grant type for device flow on Hosted
|
||||
const (
|
||||
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||
|
||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||
// for the Device Authorization Flow.
|
||||
type DeviceAuthorizationFlow struct {
|
||||
providerConfig internal.DeviceAuthProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||
type RequestDeviceCodePayload struct {
|
||||
Audience string `json:"audience"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// TokenRequestPayload used for requesting the auth0 token
|
||||
type TokenRequestPayload struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
DeviceCode string `json:"device_code,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// TokenRequestResponse used for parsing Hosted token's response
|
||||
type TokenRequestResponse struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
TokenInfo
|
||||
}
|
||||
|
||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
return &DeviceAuthorizationFlow{
|
||||
providerConfig: config,
|
||||
HTTPClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetClientID returns the provider client id
|
||||
func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
||||
return d.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
form.Add("audience", d.providerConfig.Audience)
|
||||
form.Add("scope", d.providerConfig.Scope)
|
||||
req, err := http.NewRequest("POST", d.providerConfig.DeviceAuthEndpoint,
|
||||
strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := d.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("doing request failed with error: %v", err)
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return AuthFlowInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
|
||||
}
|
||||
|
||||
deviceCode := AuthFlowInfo{}
|
||||
err = json.Unmarshal(body, &deviceCode)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
||||
}
|
||||
|
||||
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
||||
if deviceCode.VerificationURIComplete == "" {
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
form.Add("grant_type", HostedGrantType)
|
||||
form.Add("device_code", info.DeviceCode)
|
||||
|
||||
req, err := http.NewRequest("POST", d.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := d.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := res.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
||||
}
|
||||
|
||||
if res.StatusCode > 499 {
|
||||
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
||||
}
|
||||
|
||||
tokenResponse := TokenRequestResponse{}
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-ticker.C:
|
||||
|
||||
tokenResponse, err := d.requestToken(info)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
if tokenResponse.Error != "" {
|
||||
if tokenResponse.Error == "authorization_pending" {
|
||||
continue
|
||||
} else if tokenResponse.Error == "slow_down" {
|
||||
interval = interval + (3 * time.Second)
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: tokenResponse.AccessToken,
|
||||
TokenType: tokenResponse.TokenType,
|
||||
RefreshToken: tokenResponse.RefreshToken,
|
||||
IDToken: tokenResponse.IDToken,
|
||||
ExpiresIn: tokenResponse.ExpiresIn,
|
||||
UseIDToken: d.providerConfig.UseIDToken,
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenInfo.GetTokenToUse(), d.providerConfig.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,17 @@
|
||||
package internal
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -53,7 +53,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
testingErrFunc require.ErrorAssertionFunc
|
||||
expectedErrorMSG string
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedOut DeviceAuthInfo
|
||||
expectedOut AuthFlowInfo
|
||||
expectedMSG string
|
||||
expectPayload string
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: expectPayload,
|
||||
}
|
||||
testCase4Out := DeviceAuthInfo{ExpiresIn: 10}
|
||||
testCase4Out := AuthFlowInfo{ExpiresIn: 10}
|
||||
testCase4 := test{
|
||||
name: "Got Device Code",
|
||||
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
||||
@@ -113,8 +113,8 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
err: testCase.inputReqError,
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
providerConfig: ProviderConfig{
|
||||
deviceFlow := &DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
@@ -125,7 +125,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
authInfo, err := hosted.RequestDeviceCode(context.TODO())
|
||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
|
||||
@@ -145,7 +145,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
inputMaxReqs int
|
||||
inputCountResBody string
|
||||
inputTimeout time.Duration
|
||||
inputInfo DeviceAuthInfo
|
||||
inputInfo AuthFlowInfo
|
||||
inputAudience string
|
||||
testingErrFunc require.ErrorAssertionFunc
|
||||
expectedErrorMSG string
|
||||
@@ -155,7 +155,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
expectPayload string
|
||||
}
|
||||
|
||||
defaultInfo := DeviceAuthInfo{
|
||||
defaultInfo := AuthFlowInfo{
|
||||
DeviceCode: "test",
|
||||
ExpiresIn: 10,
|
||||
Interval: 1,
|
||||
@@ -278,8 +278,8 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
countResBody: testCase.inputCountResBody,
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
providerConfig: ProviderConfig{
|
||||
deviceFlow := DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
@@ -287,11 +287,12 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient}
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo)
|
||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")
|
||||
90
client/internal/auth/oauth.go
Normal file
90
client/internal/auth/oauth.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// OAuthFlow represents an interface for authorization using different OAuth 2.0 flows
|
||||
type OAuthFlow interface {
|
||||
RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error)
|
||||
WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error)
|
||||
GetClientID(ctx context.Context) string
|
||||
}
|
||||
|
||||
// HTTPClient http client interface for API calls
|
||||
type HTTPClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
||||
type AuthFlowInfo struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// Claims used when validating the access token
|
||||
type Claims struct {
|
||||
Audience interface{} `json:"aud"`
|
||||
}
|
||||
|
||||
// TokenInfo holds information of issued access token
|
||||
type TokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
UseIDToken bool `json:"-"`
|
||||
}
|
||||
|
||||
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
||||
func (t TokenInfo) GetTokenToUse() string {
|
||||
if t.UseIDToken {
|
||||
return t.IDToken
|
||||
}
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
|
||||
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
log.Debug("getting device authorization flow info")
|
||||
|
||||
// Try to initialize the Device Authorization Flow
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err == nil {
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
log.Debugf("getting device authorization flow info failed with error: %v", err)
|
||||
log.Debugf("falling back to pkce authorization flow info")
|
||||
|
||||
// If Device Authorization Flow failed, try the PKCE Authorization Flow
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||
"If you are using hosting Netbird see documentation at " +
|
||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your server or use Setup Keys to login", config.ManagementURL)
|
||||
} else {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
245
client/internal/auth/pkce_flow.go
Normal file
245
client/internal/auth/pkce_flow.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
)
|
||||
|
||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||
|
||||
const (
|
||||
queryState = "state"
|
||||
queryCode = "code"
|
||||
queryError = "error"
|
||||
queryErrorDesc = "error_description"
|
||||
defaultPKCETimeoutSeconds = 300
|
||||
)
|
||||
|
||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||
// the Authorization Code Flow with PKCE.
|
||||
type PKCEAuthorizationFlow struct {
|
||||
providerConfig internal.PKCEAuthProviderConfig
|
||||
state string
|
||||
codeVerifier string
|
||||
oAuthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
// find the first available redirect URL
|
||||
for _, redirectURL := range config.RedirectURLs {
|
||||
if !isRedirectURLPortUsed(redirectURL) {
|
||||
availableRedirectURL = redirectURL
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if availableRedirectURL == "" {
|
||||
return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
|
||||
}
|
||||
|
||||
cfg := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.AuthorizationEndpoint,
|
||||
TokenURL: config.TokenEndpoint,
|
||||
},
|
||||
RedirectURL: availableRedirectURL,
|
||||
Scopes: strings.Split(config.Scope, " "),
|
||||
}
|
||||
|
||||
return &PKCEAuthorizationFlow{
|
||||
providerConfig: config,
|
||||
oAuthConfig: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetClientID returns the provider client id
|
||||
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
||||
return p.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a authorization code login flow information.
|
||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
|
||||
state, err := randomBytesInHex(24)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
||||
}
|
||||
p.state = state
|
||||
|
||||
codeVerifier, err := randomBytesInHex(64)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
|
||||
}
|
||||
p.codeVerifier = codeVerifier
|
||||
|
||||
codeChallenge := createCodeChallenge(codeVerifier)
|
||||
authURL := p.oAuthConfig.AuthCodeURL(
|
||||
state,
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
)
|
||||
|
||||
return AuthFlowInfo{
|
||||
VerificationURIComplete: authURL,
|
||||
ExpiresIn: defaultPKCETimeoutSeconds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go p.startServer(tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.handleOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
return TokenInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
|
||||
return
|
||||
}
|
||||
port := parsedURL.Port()
|
||||
|
||||
server := http.Server{Addr: fmt.Sprintf(":%s", port)}
|
||||
defer func() {
|
||||
if err := server.Shutdown(context.Background()); err != nil {
|
||||
log.Errorf("error while shutting down pkce flow server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
tokenValidatorFunc := func() (*oauth2.Token, error) {
|
||||
query := req.URL.Query()
|
||||
|
||||
if authError := query.Get(queryError); authError != "" {
|
||||
authErrorDesc := query.Get(queryErrorDesc)
|
||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||
}
|
||||
|
||||
// Prevent timing attacks on state
|
||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
code := query.Get(queryCode)
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("missing code")
|
||||
}
|
||||
|
||||
return p.oAuthConfig.Exchange(
|
||||
req.Context(),
|
||||
code,
|
||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
||||
)
|
||||
}
|
||||
|
||||
token, err := tokenValidatorFunc()
|
||||
if err != nil {
|
||||
renderPKCEFlowTmpl(w, err)
|
||||
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
renderPKCEFlowTmpl(w, nil)
|
||||
tokenChan <- token
|
||||
})
|
||||
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
TokenType: token.TokenType,
|
||||
ExpiresIn: token.Expiry.Second(),
|
||||
UseIDToken: p.providerConfig.UseIDToken,
|
||||
}
|
||||
if idToken, ok := token.Extra("id_token").(string); ok {
|
||||
tokenInfo.IDToken = idToken
|
||||
}
|
||||
|
||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), p.providerConfig.Audience); err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
func createCodeChallenge(codeVerifier string) string {
|
||||
sha2 := sha256.Sum256([]byte(codeVerifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||
}
|
||||
|
||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
||||
parsedURL, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse redirect URL: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("error while closing the connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data := make(map[string]string)
|
||||
if authError != nil {
|
||||
data["Error"] = authError.Error()
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(w, data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
62
client/internal/auth/util.go
Normal file
62
client/internal/auth/util.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func randomBytesInHex(count int) (string, error) {
|
||||
buf := make([]byte, count)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not generate %d random bytes: %v", count, err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// isValidAccessToken is a simple validation of the access token
|
||||
func isValidAccessToken(token string, audience string) error {
|
||||
if token == "" {
|
||||
return fmt.Errorf("token received is empty")
|
||||
}
|
||||
|
||||
encodedClaims := strings.Split(token, ".")[1]
|
||||
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
claims := Claims{}
|
||||
err = json.Unmarshal(claimsString, &claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if claims.Audience == nil {
|
||||
return fmt.Errorf("required token field audience is absent")
|
||||
}
|
||||
|
||||
// Audience claim of JWT can be a string or an array of strings
|
||||
typ := reflect.TypeOf(claims.Audience)
|
||||
switch typ.Kind() {
|
||||
case reflect.String:
|
||||
if claims.Audience == audience {
|
||||
return nil
|
||||
}
|
||||
case reflect.Slice:
|
||||
for _, aud := range claims.Audience.([]interface{}) {
|
||||
if audience == aud {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid JWT token audience field")
|
||||
}
|
||||
@@ -215,10 +215,12 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
if *input.PreSharedKey != "" {
|
||||
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
}
|
||||
|
||||
if config.SSHKey == "" {
|
||||
|
||||
@@ -63,7 +63,22 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 4: existing config, but new managementURL has been provided -> update config
|
||||
// case 4: new empty pre-shared key config -> fetch it
|
||||
newPreSharedKey := ""
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: path,
|
||||
PreSharedKey: &newPreSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 5: existing config, but new managementURL has been provided -> update config
|
||||
newManagementURL := "https://test.newManagement.url:33071"
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: newManagementURL,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -24,7 +25,24 @@ import (
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error {
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
||||
}
|
||||
|
||||
// RunClientMobile with main logic on mobile system
|
||||
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
RouteListener: routeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
||||
}
|
||||
|
||||
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -151,14 +169,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
// in case of non Android os these variables will be nil
|
||||
md := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
RouteListener: routeListener,
|
||||
}
|
||||
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
|
||||
@@ -16,11 +16,11 @@ import (
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig ProviderConfig
|
||||
ProviderConfig DeviceAuthProviderConfig
|
||||
}
|
||||
|
||||
// ProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type ProviderConfig struct {
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
@@ -88,7 +88,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: ProviderConfig{
|
||||
ProviderConfig: DeviceAuthProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
@@ -105,7 +105,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isProviderConfigValid(config ProviderConfig) error {
|
||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
|
||||
@@ -15,7 +15,8 @@ const (
|
||||
fileGeneratedResolvConfSearchBeginContent = "search "
|
||||
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
|
||||
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
||||
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
||||
fileGeneratedResolvConfSearchBeginContent + "%s\n\n" +
|
||||
"%s\n"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -91,7 +92,12 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
searchDomains += " " + dConf.domain
|
||||
appendedDomains++
|
||||
}
|
||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
|
||||
|
||||
originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
|
||||
if err != nil {
|
||||
log.Errorf("Could not read existing resolv.conf")
|
||||
}
|
||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
|
||||
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
|
||||
if err != nil {
|
||||
err = f.restore()
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type androidHostManager struct {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
return &androidHostManager{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,7 +32,7 @@ type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
}
|
||||
|
||||
func newHostManager(_ *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(_ WGIface) (hostManager, error) {
|
||||
return &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}, nil
|
||||
@@ -184,12 +182,11 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
||||
primaryServiceKey := s.getPrimaryService()
|
||||
primaryServiceKey, existingNameserver := s.getPrimaryService()
|
||||
if primaryServiceKey == "" {
|
||||
return fmt.Errorf("couldn't find the primary service key")
|
||||
}
|
||||
|
||||
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port)
|
||||
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -198,27 +195,32 @@ func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getPrimaryService() string {
|
||||
func (s *systemConfigurator) getPrimaryService() (string, string) {
|
||||
line := buildCommandLine("show", globalIPv4State, "")
|
||||
stdinCommands := wrapCommand(line)
|
||||
b, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
log.Error("got error while sending the command: ", err)
|
||||
return ""
|
||||
return "", ""
|
||||
}
|
||||
scanner := bufio.NewScanner(bytes.NewReader(b))
|
||||
primaryService := ""
|
||||
router := ""
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
if strings.Contains(text, "PrimaryService") {
|
||||
return strings.TrimSpace(strings.Split(text, ":")[1])
|
||||
primaryService = strings.TrimSpace(strings.Split(text, ":")[1])
|
||||
}
|
||||
if strings.Contains(text, "Router") {
|
||||
router = strings.TrimSpace(strings.Split(text, ":")[1])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
return primaryService, router
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error {
|
||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
||||
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
|
||||
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer)
|
||||
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
|
||||
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
|
||||
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
|
||||
stdinCommands := wrapCommand(addDomainCommand)
|
||||
|
||||
@@ -5,10 +5,10 @@ package dns
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,7 +25,7 @@ const (
|
||||
|
||||
type osManagerType int
|
||||
|
||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +31,7 @@ type registryConfigurator struct {
|
||||
existingSearchDomains []string
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -31,6 +31,11 @@ func (m *MockServer) DnsIP() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
||||
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
if m.UpdateDNSServerFunc != nil {
|
||||
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
||||
}
|
||||
}
|
||||
|
||||
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -4,12 +4,11 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const resolvconfCommand = "resolvconf"
|
||||
@@ -18,7 +17,7 @@ type resolvconf struct {
|
||||
ifaceName string
|
||||
}
|
||||
|
||||
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
return &resolvconf{
|
||||
ifaceName: wgInterface.Name(),
|
||||
}, nil
|
||||
@@ -61,7 +60,11 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||
appendedDomains++
|
||||
}
|
||||
|
||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
|
||||
originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
|
||||
if err != nil {
|
||||
log.Errorf("Could not read existing resolv.conf")
|
||||
}
|
||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
|
||||
|
||||
err = r.applyConfig(content)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,29 +3,20 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
type ReadyListener interface {
|
||||
OnReady()
|
||||
}
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
@@ -33,6 +24,7 @@ type Server interface {
|
||||
Stop()
|
||||
DnsIP() string
|
||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[string]handlerWithStop
|
||||
@@ -42,21 +34,19 @@ type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
mux sync.Mutex
|
||||
udpFilterHookID string
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
localResolver *localResolver
|
||||
wgInterface *iface.WGIface
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
listenerIsRunning bool
|
||||
runtimePort int
|
||||
runtimeIP string
|
||||
previousConfigHash uint64
|
||||
currentConfig hostDNSConfig
|
||||
customAddress *netip.AddrPort
|
||||
enabled bool
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
hostsDnsList []string
|
||||
hostsDnsListLock sync.Mutex
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
@@ -70,9 +60,7 @@ type muxUpdate struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
@@ -82,37 +70,44 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
||||
addrPort = &parsedAddrPort
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
var dnsService service
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
dnsService = newServiceViaMemory(wgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
|
||||
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
||||
}
|
||||
|
||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||
ds.permanent = true
|
||||
ds.hostsDnsList = hostsDnsList
|
||||
ds.addHostRootZone()
|
||||
setServerDns(ds)
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
dnsMux: mux,
|
||||
service: dnsService,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
customAddress: addrPort,
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
|
||||
if initialDnsCfg != nil {
|
||||
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
||||
}
|
||||
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
defaultServer.evelRuntimeAddressForUserspace()
|
||||
}
|
||||
|
||||
return defaultServer, nil
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
// Initialize instantiate host manager. It required to be initialized wginterface
|
||||
// Initialize instantiate host manager and the dns service
|
||||
func (s *DefaultServer) Initialize() (err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -121,72 +116,23 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !s.wgInterface.IsUserspaceBind() {
|
||||
s.evalRuntimeAddress()
|
||||
if s.permanent {
|
||||
err = s.service.Listen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.hostManager, err = newHostManager(s.wgInterface)
|
||||
return
|
||||
}
|
||||
|
||||
// listen runs the listener in a go routine
|
||||
func (s *DefaultServer) listen() {
|
||||
// nil check required in unit tests
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||
s.udpFilterHookID = s.filterDNSTraffic()
|
||||
s.setListenerStatus(true)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// DnsIP returns the DNS resolver server IP address
|
||||
//
|
||||
// When kernel space interface used it return real DNS server listener IP address
|
||||
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
||||
func (s *DefaultServer) DnsIP() string {
|
||||
if !s.enabled {
|
||||
return ""
|
||||
}
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
return s.service.RuntimeIP()
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
@@ -202,37 +148,23 @@ func (s *DefaultServer) Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) stopListener() error {
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||
// udpFilterHookID here empty only in the unit tests
|
||||
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
|
||||
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
}
|
||||
s.udpFilterHookID = ""
|
||||
s.listenerIsRunning = false
|
||||
return nil
|
||||
}
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||
s.hostsDnsListLock.Lock()
|
||||
defer s.hostsDnsListLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
s.hostsDnsList = hostsDnsList
|
||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||
if ok {
|
||||
log.Debugf("on new host DNS config but skip to apply it")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
return nil
|
||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||
s.addHostRootZone()
|
||||
}
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
@@ -283,12 +215,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
if err := s.stopListener(); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.listen()
|
||||
if update.ServiceEnable {
|
||||
_ = s.service.Listen()
|
||||
} else if !s.permanent {
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
@@ -299,15 +229,14 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
hostUpdate := s.currentConfig
|
||||
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
||||
hostUpdate.routeAll = false
|
||||
@@ -412,19 +341,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
|
||||
var isContainRootUpdate bool
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerMux(update.domain, update.handler)
|
||||
s.service.RegisterMux(update.domain, update.handler)
|
||||
muxUpdateMap[update.domain] = update.handler
|
||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||
existingHandler.stop()
|
||||
}
|
||||
|
||||
if update.domain == nbdns.RootZone {
|
||||
isContainRootUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
for key, existingHandler := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
existingHandler.stop()
|
||||
s.deregisterMux(key)
|
||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||
s.hostsDnsListLock.Lock()
|
||||
s.addHostRootZone()
|
||||
s.hostsDnsListLock.Unlock()
|
||||
existingHandler.stop()
|
||||
} else {
|
||||
existingHandler.stop()
|
||||
s.service.DeregisterMux(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -455,14 +397,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||
// the upstream resolver from the configuration, the second one is used to
|
||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||
@@ -490,7 +424,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
for i, item := range s.currentConfig.domains {
|
||||
if _, found := removeIndex[item.domain]; found {
|
||||
s.currentConfig.domains[i].disabled = true
|
||||
s.deregisterMux(item.domain)
|
||||
s.service.DeregisterMux(item.domain)
|
||||
removeIndex[item.domain] = i
|
||||
}
|
||||
}
|
||||
@@ -507,7 +441,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
continue
|
||||
}
|
||||
s.currentConfig.domains[i].disabled = false
|
||||
s.registerMux(domain, handler)
|
||||
s.service.RegisterMux(domain, handler)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
@@ -523,93 +457,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
return
|
||||
}
|
||||
|
||||
func (s *DefaultServer) filterDNSTraffic() string {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
log.Error("can't set DNS filter, filter not initialized")
|
||||
return ""
|
||||
func (s *DefaultServer) addHostRootZone() {
|
||||
handler := newUpstreamResolver(s.ctx)
|
||||
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
||||
for n, ua := range s.hostsDnsList {
|
||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
log.Tracef("parse DNS request: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
|
||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
||||
s.runtimePort = defaultPort
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) evalRuntimeAddress() {
|
||||
defer func() {
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
}()
|
||||
|
||||
if s.customAddress != nil {
|
||||
s.runtimeIP = s.customAddress.Addr().String()
|
||||
s.runtimePort = int(s.customAddress.Port())
|
||||
return
|
||||
}
|
||||
|
||||
ip, port, err := s.getFirstListenerAvailable()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
s.runtimeIP = ip
|
||||
s.runtimePort = port
|
||||
}
|
||||
|
||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||
// Calculate the last IP in the CIDR range
|
||||
var endIP net.IP
|
||||
for i := 0; i < len(network.IP); i++ {
|
||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
endInt := big.NewInt(0)
|
||||
endInt.SetBytes(endIP)
|
||||
|
||||
// subtract fromEnd from the last ip
|
||||
fromEndBig := big.NewInt(int64(fromEnd))
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
return net.IP(resultInt.Bytes()).String()
|
||||
}
|
||||
|
||||
func hasValidDnsServer(cfg *nbdns.Config) bool {
|
||||
for _, c := range cfg.NameServerGroups {
|
||||
if c.Primary {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
handler.deactivate = func() {}
|
||||
handler.reactivate = func() {}
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
}
|
||||
|
||||
29
client/internal/dns/server_export.go
Normal file
29
client/internal/dns/server_export.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
mutex sync.Mutex
|
||||
server Server
|
||||
)
|
||||
|
||||
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
|
||||
func GetServerDns() (Server, error) {
|
||||
mutex.Lock()
|
||||
if server == nil {
|
||||
mutex.Unlock()
|
||||
return nil, fmt.Errorf("DNS server not instantiated yet")
|
||||
}
|
||||
s := server
|
||||
mutex.Unlock()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func setServerDns(newServerServer Server) {
|
||||
mutex.Lock()
|
||||
server = newServerServer
|
||||
defer mutex.Unlock()
|
||||
}
|
||||
24
client/internal/dns/server_export_test.go
Normal file
24
client/internal/dns/server_export_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetServerDns(t *testing.T) {
|
||||
_, err := GetServerDns()
|
||||
if err == nil {
|
||||
t.Errorf("invalid dns server instance")
|
||||
}
|
||||
|
||||
srv := &MockServer{}
|
||||
setServerDns(srv)
|
||||
|
||||
srvB, err := GetServerDns()
|
||||
if err != nil {
|
||||
t.Errorf("invalid dns server instance: %s", err)
|
||||
}
|
||||
|
||||
if srvB != srv {
|
||||
t.Errorf("missmatch dns instances")
|
||||
}
|
||||
}
|
||||
@@ -11,14 +11,53 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||
)
|
||||
|
||||
type mocWGIface struct {
|
||||
filter iface.PacketFilter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Name() string {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() iface.WGAddress {
|
||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||
return iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||
return w.filter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) IsUserspaceBind() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
||||
w.filter = filter
|
||||
return nil
|
||||
}
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "peera.netbird.cloud",
|
||||
@@ -29,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
@@ -224,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -242,8 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
// pretend we are running
|
||||
dnsServer.listenerIsRunning = true
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
@@ -282,7 +324,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
@@ -316,17 +358,17 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
@@ -421,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
|
||||
|
||||
dnsServer.hostManager = newNoopHostMocker()
|
||||
dnsServer.listen()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if !dnsServer.listenerIsRunning {
|
||||
t.Fatal("dns server listener is not running")
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
dnsServer.hostManager = newNoopHostMocker()
|
||||
err = dnsServer.service.Listen()
|
||||
if err != nil {
|
||||
t.Fatalf("dns server is not running: %s", err)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer dnsServer.Stop()
|
||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
||||
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
@@ -443,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
|
||||
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
conn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
@@ -478,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
dnsMux: dns.DefaultServeMux,
|
||||
service: newServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
@@ -541,62 +585,237 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
||||
mux := dns.NewServeMux()
|
||||
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
|
||||
var parsedAddrPort *netip.AddrPort
|
||||
if addrPort != "" {
|
||||
parsed, err := netip.ParseAddrPort(addrPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
parsedAddrPort = &parsed
|
||||
var dnsList []string
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
||||
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
// check initial state
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
ds := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: cancel,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Enabled: true,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
customAddress: parsedAddrPort,
|
||||
}
|
||||
ds.evalRuntimeAddress()
|
||||
return ds
|
||||
}
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
err = dnsServer.UpdateDNSServer(1, update)
|
||||
if err != nil {
|
||||
t.Errorf("failed to update dns server: %s", err)
|
||||
}
|
||||
|
||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed resolve zone record: %v", err)
|
||||
}
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
|
||||
update2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{},
|
||||
}
|
||||
|
||||
err = dnsServer.UpdateDNSServer(2, update2)
|
||||
if err != nil {
|
||||
t.Errorf("failed to update dns server: %s", err)
|
||||
}
|
||||
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed resolve zone record: %v", err)
|
||||
}
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
// check initial state
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Domains: []string{"customdomain.com"},
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = dnsServer.UpdateDNSServer(1, update)
|
||||
if err != nil {
|
||||
t.Errorf("failed to update dns server: %s", err)
|
||||
}
|
||||
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed resolve zone record: %v", err)
|
||||
}
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create stdnet: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatalf("build interface wireguard: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatalf("crate and init wireguard interface: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = wgIface.SetFilter(pf)
|
||||
if err != nil {
|
||||
t.Fatalf("set packet filter: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wgIface, nil
|
||||
}
|
||||
|
||||
func newDnsResolver(ip string, port int) *net.Resolver {
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 3,
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", ip, port)
|
||||
return d.DialContext(ctx, network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
18
client/internal/dns/service.go
Normal file
18
client/internal/dns/service.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
)
|
||||
|
||||
type service interface {
|
||||
Listen() error
|
||||
Stop()
|
||||
RegisterMux(domain string, handler dns.Handler)
|
||||
DeregisterMux(key string)
|
||||
RuntimePort() int
|
||||
RuntimeIP() string
|
||||
}
|
||||
145
client/internal/dns/service_listener.go
Normal file
145
client/internal/dns/service_listener.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
|
||||
type serviceViaListener struct {
|
||||
wgInterface WGIface
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
runtimeIP string
|
||||
runtimePort int
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
s := &serviceViaListener{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: mux,
|
||||
customAddr: customAddr,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Listen() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
|
||||
if err != nil {
|
||||
log.Errorf("failed to eval runtime address: %s", err)
|
||||
return err
|
||||
}
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) DeregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RuntimePort() int {
|
||||
return s.runtimePort
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RuntimeIP() string {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
|
||||
if s.customAddr != nil {
|
||||
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
||||
}
|
||||
|
||||
return s.getFirstListenerAvailable()
|
||||
}
|
||||
139
client/internal/dns/service_memory.go
Normal file
139
client/internal/dns/service_memory.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type serviceViaMemory struct {
|
||||
wgInterface WGIface
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP string
|
||||
runtimePort int
|
||||
udpFilterHookID string
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
|
||||
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||
s := &serviceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
|
||||
runtimePort: defaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Listen() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
log.Debugf("dns service listening on: %s", s.RuntimeIP())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
|
||||
s.listenerIsRunning = false
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimePort() int {
|
||||
return s.runtimePort
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimeIP() string {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
log.Tracef("parse DNS request: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||
}
|
||||
|
||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||
// Calculate the last IP in the CIDR range
|
||||
var endIP net.IP
|
||||
for i := 0; i < len(network.IP); i++ {
|
||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
endInt := big.NewInt(0)
|
||||
endInt.SetBytes(endIP)
|
||||
|
||||
// subtract fromEnd from the last ip
|
||||
fromEndBig := big.NewInt(int64(fromEnd))
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
return net.IP(resultInt.Bytes()).String()
|
||||
}
|
||||
31
client/internal/dns/service_memory_test.go
Normal file
31
client/internal/dns/service_memory_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
|
||||
MatchOnly bool
|
||||
}
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
14
client/internal/dns/wgiface.go
Normal file
14
client/internal/dns/wgiface.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !windows
|
||||
|
||||
package dns
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
}
|
||||
13
client/internal/dns/wgiface_windows.go
Normal file
13
client/internal/dns/wgiface_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package dns
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
@@ -101,7 +101,8 @@ type Engine struct {
|
||||
|
||||
ctx context.Context
|
||||
|
||||
wgInterface *iface.WGIface
|
||||
wgInterface *iface.WGIface
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
udpMuxConn io.Closer
|
||||
@@ -132,6 +133,7 @@ func NewEngine(
|
||||
signalClient signal.Client, mgmClient mgm.Client,
|
||||
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
|
||||
) *Engine {
|
||||
|
||||
return &Engine{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
@@ -146,6 +148,7 @@ func NewEngine(
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,23 +193,25 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
|
||||
var routes []*route.Route
|
||||
var dnsCfg *nbdns.Config
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
routes, dnsCfg, err = e.readInitialSettings()
|
||||
routes, err = e.readInitialSettings()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if e.dnsServer == nil {
|
||||
if e.dnsServer == nil {
|
||||
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
|
||||
go e.mobileDep.DnsReadyListener.OnReady()
|
||||
}
|
||||
} else {
|
||||
// todo fix custom address
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
if e.dnsServer == nil {
|
||||
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
}
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
||||
@@ -280,7 +285,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
for _, p := range peersUpdate {
|
||||
peerPubKey := p.GetWgPubKey()
|
||||
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
||||
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") {
|
||||
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
|
||||
modified = append(modified, p)
|
||||
continue
|
||||
}
|
||||
@@ -793,9 +798,7 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
||||
|
||||
// we might have received new STUN and TURN servers meanwhile, so update them
|
||||
e.syncMsgMux.Lock()
|
||||
conf := conn.GetConf()
|
||||
conf.StunTurn = append(e.STUNs, e.TURNs...)
|
||||
conn.UpdateConf(conf)
|
||||
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
err := conn.Open()
|
||||
@@ -824,9 +827,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
|
||||
proxyConfig := proxy.Config{
|
||||
wgConfig := peer.WgConfig{
|
||||
RemoteKey: pubKey,
|
||||
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
|
||||
WgListenPort: e.config.WgPort,
|
||||
WgInterface: e.wgInterface,
|
||||
AllowedIps: allowedIPs,
|
||||
PreSharedKey: e.config.PreSharedKey,
|
||||
@@ -843,13 +846,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
Timeout: timeout,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
ProxyConfig: proxyConfig,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1006,6 +1009,10 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
}
|
||||
|
||||
func (e *Engine) close() {
|
||||
if err := e.wgProxyFactory.Free(); err != nil {
|
||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
@@ -1045,14 +1052,13 @@ func (e *Engine) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, error) {
|
||||
netMap, err := e.mgmClient.GetNetworkMap()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||
return routes, &dnsCfg, nil
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
|
||||
@@ -367,9 +367,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
||||
}
|
||||
expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
|
||||
if conn.GetConf().ProxyConfig.AllowedIps != expectedAllowedIPs {
|
||||
if conn.WgConfig().AllowedIps != expectedAllowedIPs {
|
||||
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
|
||||
expectedAllowedIPs, conn.GetConf().ProxyConfig.AllowedIps)
|
||||
expectedAllowedIPs, conn.WgConfig().AllowedIps)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
@@ -8,7 +9,9 @@ import (
|
||||
|
||||
// MobileDependency collect all dependencies for mobile platform
|
||||
type MobileDependency struct {
|
||||
TunAdapter iface.TunAdapter
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
RouteListener routemanager.RouteListener
|
||||
TunAdapter iface.TunAdapter
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
RouteListener routemanager.RouteListener
|
||||
HostDNSAddresses []string
|
||||
DnsReadyListener dns.ReadyListener
|
||||
}
|
||||
|
||||
@@ -1,286 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthClient is a OAuth client interface for various idp providers
|
||||
type OAuthClient interface {
|
||||
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
|
||||
GetClientID(ctx context.Context) string
|
||||
}
|
||||
|
||||
// HTTPClient http client interface for API calls
|
||||
type HTTPClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// DeviceAuthInfo holds information for the OAuth device login flow
|
||||
type DeviceAuthInfo struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// HostedGrantType grant type for device flow on Hosted
|
||||
const (
|
||||
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
HostedRefreshGrant = "refresh_token"
|
||||
)
|
||||
|
||||
// Hosted client
|
||||
type Hosted struct {
|
||||
providerConfig ProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||
type RequestDeviceCodePayload struct {
|
||||
Audience string `json:"audience"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// TokenRequestPayload used for requesting the auth0 token
|
||||
type TokenRequestPayload struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
DeviceCode string `json:"device_code,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// TokenRequestResponse used for parsing Hosted token's response
|
||||
type TokenRequestResponse struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
TokenInfo
|
||||
}
|
||||
|
||||
// Claims used when validating the access token
|
||||
type Claims struct {
|
||||
Audience interface{} `json:"aud"`
|
||||
}
|
||||
|
||||
// TokenInfo holds information of issued access token
|
||||
type TokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
UseIDToken bool `json:"-"`
|
||||
}
|
||||
|
||||
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
||||
func (t TokenInfo) GetTokenToUse() string {
|
||||
if t.UseIDToken {
|
||||
return t.IDToken
|
||||
}
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
// NewHostedDeviceFlow returns an Hosted OAuth client
|
||||
func NewHostedDeviceFlow(config ProviderConfig) *Hosted {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
return &Hosted{
|
||||
providerConfig: config,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientID returns the provider client id
|
||||
func (h *Hosted) GetClientID(ctx context.Context) string {
|
||||
return h.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// RequestDeviceCode requests a device code login flow information from Hosted
|
||||
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", h.providerConfig.ClientID)
|
||||
form.Add("audience", h.providerConfig.Audience)
|
||||
form.Add("scope", h.providerConfig.Scope)
|
||||
req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint,
|
||||
strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := h.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("doing request failed with error: %v", err)
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
|
||||
}
|
||||
|
||||
deviceCode := DeviceAuthInfo{}
|
||||
err = json.Unmarshal(body, &deviceCode)
|
||||
if err != nil {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
||||
}
|
||||
|
||||
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
||||
if deviceCode.VerificationURIComplete == "" {
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", h.providerConfig.ClientID)
|
||||
form.Add("grant_type", HostedGrantType)
|
||||
form.Add("device_code", info.DeviceCode)
|
||||
req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := h.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := res.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
||||
}
|
||||
|
||||
if res.StatusCode > 499 {
|
||||
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
||||
}
|
||||
|
||||
tokenResponse := TokenRequestResponse{}
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-ticker.C:
|
||||
|
||||
tokenResponse, err := h.requestToken(info)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
if tokenResponse.Error != "" {
|
||||
if tokenResponse.Error == "authorization_pending" {
|
||||
continue
|
||||
} else if tokenResponse.Error == "slow_down" {
|
||||
interval = interval + (3 * time.Second)
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: tokenResponse.AccessToken,
|
||||
TokenType: tokenResponse.TokenType,
|
||||
RefreshToken: tokenResponse.RefreshToken,
|
||||
IDToken: tokenResponse.IDToken,
|
||||
ExpiresIn: tokenResponse.ExpiresIn,
|
||||
UseIDToken: h.providerConfig.UseIDToken,
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isValidAccessToken is a simple validation of the access token
|
||||
func isValidAccessToken(token string, audience string) error {
|
||||
if token == "" {
|
||||
return fmt.Errorf("token received is empty")
|
||||
}
|
||||
|
||||
encodedClaims := strings.Split(token, ".")[1]
|
||||
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
claims := Claims{}
|
||||
err = json.Unmarshal(claimsString, &claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if claims.Audience == nil {
|
||||
return fmt.Errorf("required token field audience is absent")
|
||||
}
|
||||
|
||||
// Audience claim of JWT can be a string or an array of strings
|
||||
typ := reflect.TypeOf(claims.Audience)
|
||||
switch typ.Kind() {
|
||||
case reflect.String:
|
||||
if claims.Audience == audience {
|
||||
return nil
|
||||
}
|
||||
case reflect.Slice:
|
||||
for _, aud := range claims.Audience.([]interface{}) {
|
||||
if audience == aud {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid JWT token audience field")
|
||||
}
|
||||
@@ -10,9 +10,10 @@ import (
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
@@ -23,8 +24,18 @@ import (
|
||||
const (
|
||||
iceKeepAliveDefault = 4 * time.Second
|
||||
iceDisconnectedTimeoutDefault = 6 * time.Second
|
||||
|
||||
defaultWgKeepAlive = 25 * time.Second
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
WgListenPort int
|
||||
RemoteKey string
|
||||
WgInterface *iface.WGIface
|
||||
AllowedIps string
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
|
||||
// ConnConfig is a peer Connection configuration
|
||||
type ConnConfig struct {
|
||||
|
||||
@@ -43,7 +54,7 @@ type ConnConfig struct {
|
||||
|
||||
Timeout time.Duration
|
||||
|
||||
ProxyConfig proxy.Config
|
||||
WgConfig WgConfig
|
||||
|
||||
UDPMux ice.UDPMux
|
||||
UDPMuxSrflx ice.UniversalUDPMux
|
||||
@@ -98,7 +109,9 @@ type Conn struct {
|
||||
|
||||
statusRecorder *Status
|
||||
|
||||
proxy proxy.Proxy
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
wgProxy wgproxy.Proxy
|
||||
|
||||
remoteModeCh chan ModeMessage
|
||||
meta meta
|
||||
|
||||
@@ -122,14 +135,19 @@ func (conn *Conn) GetConf() ConnConfig {
|
||||
return conn.config
|
||||
}
|
||||
|
||||
// UpdateConf updates the connection config
|
||||
func (conn *Conn) UpdateConf(conf ConnConfig) {
|
||||
conn.config = conf
|
||||
// WgConfig returns the WireGuard config
|
||||
func (conn *Conn) WgConfig() WgConfig {
|
||||
return conn.config.WgConfig
|
||||
}
|
||||
|
||||
// UpdateStunTurn update the turn and stun addresses
|
||||
func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
|
||||
conn.config.StunTurn = turnStun
|
||||
}
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
||||
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
||||
return &Conn{
|
||||
config: config,
|
||||
mu: sync.Mutex{},
|
||||
@@ -139,6 +157,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
statusRecorder: statusRecorder,
|
||||
remoteModeCh: make(chan ModeMessage, 1),
|
||||
wgProxyFactory: wgProxyFactory,
|
||||
adapter: adapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
}, nil
|
||||
@@ -215,12 +234,12 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
|
||||
func (conn *Conn) Open() error {
|
||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||
|
||||
peerState := State{PubKey: conn.config.Key}
|
||||
|
||||
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
peerState.ConnStatus = conn.status
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatus: conn.status,
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||
@@ -275,10 +294,11 @@ func (conn *Conn) Open() error {
|
||||
defer conn.notifyDisconnected()
|
||||
conn.mu.Unlock()
|
||||
|
||||
peerState = State{PubKey: conn.config.Key}
|
||||
|
||||
peerState.ConnStatus = conn.status
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
peerState = State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||
@@ -309,19 +329,12 @@ func (conn *Conn) Open() error {
|
||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||
}
|
||||
// the ice connection has been established successfully so we are ready to start the proxy
|
||||
err = conn.startProxy(remoteConn, remoteWgPort)
|
||||
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
|
||||
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
|
||||
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
|
||||
// direct Wireguard connection
|
||||
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
|
||||
} else {
|
||||
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
|
||||
}
|
||||
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
|
||||
|
||||
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
||||
select {
|
||||
@@ -338,54 +351,60 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (net.Addr, error) {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
var pair *ice.CandidatePair
|
||||
pair, err := conn.agent.GetSelectedCandidatePair()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerState := State{PubKey: conn.config.Key}
|
||||
p := conn.getProxy(pair, remoteWgPort)
|
||||
conn.proxy = p
|
||||
err = p.Start(remoteConn)
|
||||
var endpoint net.Addr
|
||||
if isRelayCandidate(pair.Local) {
|
||||
log.Debugf("setup relay connection")
|
||||
conn.wgProxy = conn.wgProxyFactory.GetProxy()
|
||||
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
|
||||
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
||||
endpoint = remoteConn.RemoteAddr()
|
||||
}
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
|
||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
if conn.wgProxy != nil {
|
||||
_ = conn.wgProxy.CloseConn()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.status = StatusConnected
|
||||
|
||||
peerState.ConnStatus = conn.status
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
||||
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
LocalIceCandidateType: pair.Local.Type().String(),
|
||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||
Direct: !isRelayCandidate(pair.Local),
|
||||
}
|
||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||
peerState.Relayed = true
|
||||
}
|
||||
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
|
||||
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
log.Warnf("unable to save peer's state, got error: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// todo rename this method and the proxy package to something more appropriate
|
||||
func (conn *Conn) getProxy(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
|
||||
if isRelayCandidate(pair.Local) {
|
||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
||||
}
|
||||
|
||||
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
|
||||
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
||||
|
||||
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
@@ -414,22 +433,22 @@ func (conn *Conn) cleanup() error {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
var err1, err2, err3 error
|
||||
if conn.agent != nil {
|
||||
err := conn.agent.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
err1 = conn.agent.Close()
|
||||
if err1 == nil {
|
||||
conn.agent = nil
|
||||
}
|
||||
conn.agent = nil
|
||||
}
|
||||
|
||||
if conn.proxy != nil {
|
||||
err := conn.proxy.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.proxy = nil
|
||||
if conn.wgProxy != nil {
|
||||
err2 = conn.wgProxy.CloseConn()
|
||||
conn.wgProxy = nil
|
||||
}
|
||||
|
||||
// todo: is it problem if we try to remove a peer what is never existed?
|
||||
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
|
||||
if conn.notifyDisconnected != nil {
|
||||
conn.notifyDisconnected()
|
||||
conn.notifyDisconnected = nil
|
||||
@@ -437,10 +456,11 @@ func (conn *Conn) cleanup() error {
|
||||
|
||||
conn.status = StatusDisconnected
|
||||
|
||||
peerState := State{PubKey: conn.config.Key}
|
||||
peerState.ConnStatus = conn.status
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
||||
@@ -449,8 +469,13 @@ func (conn *Conn) cleanup() error {
|
||||
}
|
||||
|
||||
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
||||
|
||||
return nil
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
return err3
|
||||
}
|
||||
|
||||
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
|
||||
|
||||
@@ -5,12 +5,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/pion/ice/v2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
@@ -20,7 +19,6 @@ var connConf = ConnConfig{
|
||||
StunTurn: []*ice.URL{},
|
||||
InterfaceBlackList: nil,
|
||||
Timeout: time.Second,
|
||||
ProxyConfig: proxy.Config{},
|
||||
LocalWgPort: 51820,
|
||||
}
|
||||
|
||||
@@ -37,7 +35,11 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
conn, err := NewConn(connConf, nil, nil, nil)
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -48,8 +50,11 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -82,8 +87,11 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -115,8 +123,11 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -142,8 +153,11 @@ func TestConn_Status(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
128
client/internal/pkce_auth.go
Normal file
128
client/internal/pkce_auth.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
)
|
||||
|
||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
||||
type PKCEAuthorizationFlow struct {
|
||||
ProviderConfig PKCEAuthProviderConfig
|
||||
}
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
authFlow := PKCEAuthorizationFlow{
|
||||
ProviderConfig: PKCEAuthProviderConfig{
|
||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
},
|
||||
}
|
||||
|
||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return authFlow, nil
|
||||
}
|
||||
|
||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DummyProxy just sends pings to the RemoteKey peer and reads responses
|
||||
type DummyProxy struct {
|
||||
conn net.Conn
|
||||
remote string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewDummyProxy(remote string) *DummyProxy {
|
||||
p := &DummyProxy{remote: remote}
|
||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *DummyProxy) Close() error {
|
||||
p.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DummyProxy) Start(remoteConn net.Conn) error {
|
||||
p.conn = remoteConn
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
_, err := p.conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Errorf("error while reading RemoteKey %s proxy %v", p.remote, err)
|
||||
return
|
||||
}
|
||||
//log.Debugf("received %s from %s", string(buf[:n]), p.remote)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
_, err := p.conn.Write([]byte("hello"))
|
||||
//log.Debugf("sent ping to %s", p.remote)
|
||||
if err != nil {
|
||||
log.Errorf("error while writing to RemoteKey %s proxy %v", p.remote, err)
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DummyProxy) Type() Type {
|
||||
return TypeDummy
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
)
|
||||
|
||||
// NoProxy is used just to configure WireGuard without any local proxy in between.
|
||||
// Used when the WireGuard interface is userspace and uses bind.ICEBind
|
||||
type NoProxy struct {
|
||||
config Config
|
||||
}
|
||||
|
||||
// NewNoProxy creates a new NoProxy with a provided config
|
||||
func NewNoProxy(config Config) *NoProxy {
|
||||
return &NoProxy{config: config}
|
||||
}
|
||||
|
||||
// Close removes peer from the WireGuard interface
|
||||
func (p *NoProxy) Close() error {
|
||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start just updates WireGuard peer with the remote address
|
||||
func (p *NoProxy) Start(remoteConn net.Conn) error {
|
||||
|
||||
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
|
||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
addr, p.config.PreSharedKey)
|
||||
}
|
||||
|
||||
func (p *NoProxy) Type() Type {
|
||||
return TypeNoProxy
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const DefaultWgKeepAlive = 25 * time.Second
|
||||
|
||||
type Type string
|
||||
|
||||
const (
|
||||
TypeDirectNoProxy Type = "DirectNoProxy"
|
||||
TypeWireGuard Type = "WireGuard"
|
||||
TypeDummy Type = "Dummy"
|
||||
TypeNoProxy Type = "NoProxy"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
WgListenAddr string
|
||||
RemoteKey string
|
||||
WgInterface *iface.WGIface
|
||||
AllowedIps string
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
|
||||
type Proxy interface {
|
||||
io.Closer
|
||||
// Start creates a local remoteConn and starts proxying data from/to remoteConn
|
||||
Start(remoteConn net.Conn) error
|
||||
Type() Type
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
)
|
||||
|
||||
// WireGuardProxy proxies
|
||||
type WireGuardProxy struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
config Config
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
}
|
||||
|
||||
func NewWireGuardProxy(config Config) *WireGuardProxy {
|
||||
p := &WireGuardProxy{config: config}
|
||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *WireGuardProxy) updateEndpoint() error {
|
||||
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// add local proxy connection as a Wireguard peer
|
||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
udpAddr, p.config.PreSharedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
p.localConn, err = net.Dial("udp", p.config.WgListenAddr)
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.updateEndpoint()
|
||||
if err != nil {
|
||||
log.Errorf("error while updating Wireguard peer endpoint [%s] %v", p.config.RemoteKey, err)
|
||||
return err
|
||||
}
|
||||
|
||||
go p.proxyToRemote()
|
||||
go p.proxyToLocal()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WireGuardProxy) Close() error {
|
||||
p.cancel()
|
||||
if c := p.localConn; c != nil {
|
||||
err := p.localConn.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||
// blocks
|
||||
func (p *WireGuardProxy) proxyToRemote() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.config.RemoteKey)
|
||||
return
|
||||
default:
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||
// blocks
|
||||
func (p *WireGuardProxy) proxyToLocal() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.config.RemoteKey)
|
||||
return
|
||||
default:
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WireGuardProxy) Type() Type {
|
||||
return TypeWireGuard
|
||||
}
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -30,33 +28,13 @@ func genKey(format string, input string) string {
|
||||
|
||||
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
||||
func NewFirewall(parentCTX context.Context) firewallManager {
|
||||
ctx, cancel := context.WithCancel(parentCTX)
|
||||
|
||||
if isIptablesSupported() {
|
||||
log.Debugf("iptables is supported")
|
||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
|
||||
return &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
manager, err := newNFTablesManager(parentCTX)
|
||||
if err == nil {
|
||||
log.Debugf("nftables firewall manager will be used")
|
||||
return manager
|
||||
}
|
||||
|
||||
log.Debugf("iptables is not supported, using nftables")
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
return manager
|
||||
log.Debugf("fallback to iptables firewall manager: %s", err)
|
||||
return newIptablesManager(parentCTX)
|
||||
}
|
||||
|
||||
func getInPair(pair routerPair) routerPair {
|
||||
|
||||
@@ -49,6 +49,32 @@ type iptablesManager struct {
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newIptablesManager(parentCtx context.Context) *iptablesManager {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize iptables for ipv4: %s", err)
|
||||
} else if !isIptablesClientAvailable(ipv4Client) {
|
||||
log.Infof("iptables is missing for ipv4")
|
||||
ipv4Client = nil
|
||||
}
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize iptables for ipv6: %s", err)
|
||||
} else if !isIptablesClientAvailable(ipv6Client) {
|
||||
log.Infof("iptables is missing for ipv6")
|
||||
ipv6Client = nil
|
||||
}
|
||||
|
||||
return &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
||||
func (i *iptablesManager) CleanRoutingRules() {
|
||||
i.mux.Lock()
|
||||
@@ -61,24 +87,28 @@ func (i *iptablesManager) CleanRoutingRules() {
|
||||
|
||||
log.Debug("flushing tables")
|
||||
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
if i.ipv4Client != nil {
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
}
|
||||
if i.ipv6Client != nil {
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("done cleaning up iptables rules")
|
||||
@@ -96,37 +126,41 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
|
||||
|
||||
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
||||
|
||||
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
if i.ipv4Client != nil {
|
||||
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv4Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
if i.ipv6Client != nil {
|
||||
err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv6Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv4Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv6Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
||||
}
|
||||
|
||||
err = i.addJumpRules()
|
||||
err := i.addJumpRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
||||
}
|
||||
@@ -140,34 +174,38 @@ func (i *iptablesManager) addJumpRules() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
||||
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
if i.ipv4Client != nil {
|
||||
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
||||
|
||||
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4][ipv4Forwarding] = rule
|
||||
|
||||
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
||||
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4][ipv4Nat] = rule
|
||||
}
|
||||
|
||||
i.rules[ipv4][ipv4Forwarding] = rule
|
||||
if i.ipv6Client != nil {
|
||||
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
||||
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Forwarding] = rule
|
||||
|
||||
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
||||
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
rule = append(iptablesDefaultNatRule, ipv6Nat)
|
||||
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Nat] = rule
|
||||
}
|
||||
i.rules[ipv4][ipv4Nat] = rule
|
||||
|
||||
rule = append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
||||
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Forwarding] = rule
|
||||
|
||||
rule = append(iptablesDefaultNatRule, ipv6Nat)
|
||||
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -177,35 +215,39 @@ func (i *iptablesManager) cleanJumpRules() error {
|
||||
var err error
|
||||
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
||||
rule, found := i.rules[ipv4][ipv4Forwarding]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
|
||||
if i.ipv4Client != nil {
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4][ipv4Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4][ipv4Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
||||
if i.ipv6Client == nil {
|
||||
rule, found = i.rules[ipv6][ipv6Forwarding]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv6][ipv6Forwarding]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv6][ipv6Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
||||
rule, found = i.rules[ipv6][ipv6Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -437,3 +479,8 @@ func getIptablesRuleType(table string) string {
|
||||
}
|
||||
return ruleType
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -16,17 +16,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
|
||||
manager := &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
manager := newIptablesManager(context.TODO())
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
@@ -37,21 +27,21 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
||||
|
||||
exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
||||
exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
||||
exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
||||
|
||||
exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
||||
exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
||||
exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
@@ -64,13 +54,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
||||
err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4RuleKey := genKey(natFormat, pair.ID)
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
||||
err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
pair = routerPair{
|
||||
@@ -83,13 +73,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
||||
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
||||
err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat6RuleKey := genKey(natFormat, pair.ID)
|
||||
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
||||
err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
delete(manager.rules, ipv4)
|
||||
|
||||
@@ -19,6 +19,9 @@ const (
|
||||
nftablesTable = "netbird-rt"
|
||||
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
||||
nftablesRoutingNatChain = "netbird-rt-nat"
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
)
|
||||
|
||||
// constants needed to create nftable rules
|
||||
@@ -71,14 +74,41 @@ var (
|
||||
)
|
||||
|
||||
type nftablesManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
tableIPv4 *nftables.Table
|
||||
tableIPv6 *nftables.Table
|
||||
chains map[string]map[string]*nftables.Chain
|
||||
rules map[string]*nftables.Rule
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
tableIPv4 *nftables.Table
|
||||
tableIPv6 *nftables.Table
|
||||
chains map[string]map[string]*nftables.Chain
|
||||
rules map[string]*nftables.Rule
|
||||
filterTable *nftables.Table
|
||||
defaultForwardRules []*nftables.Rule
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
mgr := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
defaultForwardRules: make([]*nftables.Rule, 2),
|
||||
}
|
||||
|
||||
err := mgr.isSupported()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = mgr.readFilterTable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// CleanRoutingRules cleans existing nftables rules from the system
|
||||
@@ -90,6 +120,13 @@ func (n *nftablesManager) CleanRoutingRules() {
|
||||
n.conn.FlushTable(n.tableIPv6)
|
||||
n.conn.FlushTable(n.tableIPv4)
|
||||
}
|
||||
|
||||
if n.defaultForwardRules[0] != nil {
|
||||
err := n.eraseDefaultForwardRule()
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete forward rule: %s", err)
|
||||
}
|
||||
}
|
||||
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
||||
}
|
||||
|
||||
@@ -222,6 +259,112 @@ func (n *nftablesManager) refreshRulesMap() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nftablesManager) readFilterTable() error {
|
||||
tables, err := n.conn.ListTables()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == "filter" {
|
||||
n.filterTable = t
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nftablesManager) eraseDefaultForwardRule() error {
|
||||
if n.defaultForwardRules[0] == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := n.refreshDefaultForwardRule()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, r := range n.defaultForwardRules {
|
||||
err = n.conn.DelRule(r)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete forward rule (%d): %s", i, err)
|
||||
}
|
||||
n.defaultForwardRules[i] = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nftablesManager) refreshDefaultForwardRule() error {
|
||||
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to list rules in forward chain: %s", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for i, r := range n.defaultForwardRules {
|
||||
for _, rule := range rules {
|
||||
if string(rule.UserData) == string(r.UserData) {
|
||||
n.defaultForwardRules[i] = rule
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("unable to find forward accept rule")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
|
||||
src := generateCIDRMatcherExpressions("source", sourceNetwork)
|
||||
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
|
||||
|
||||
var exprs []expr.Any
|
||||
exprs = append(src, append(dst, &expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
r := &nftables.Rule{
|
||||
Table: n.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: n.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||
}
|
||||
|
||||
n.defaultForwardRules[0] = n.conn.AddRule(r)
|
||||
|
||||
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
|
||||
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
|
||||
|
||||
exprs = append(src, append(dst, &expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
r = &nftables.Rule{
|
||||
Table: n.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: n.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||
}
|
||||
|
||||
n.defaultForwardRules[1] = n.conn.AddRule(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
||||
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
||||
_, foundIPv4 := n.rules[ipv4Forwarding]
|
||||
@@ -275,6 +418,14 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
||||
}
|
||||
}
|
||||
|
||||
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
|
||||
err = n.acceptForwardRule(pair.source)
|
||||
if err != nil {
|
||||
log.Errorf("unable to create default forward rule: %s", err)
|
||||
}
|
||||
log.Debugf("default accept forward rule added")
|
||||
}
|
||||
|
||||
err = n.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
||||
@@ -355,6 +506,13 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
|
||||
err := n.eraseDefaultForwardRule()
|
||||
if err != nil {
|
||||
log.Errorf("failed to delte default fwd rule: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = n.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
||||
@@ -386,6 +544,14 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nftablesManager) isSupported() error {
|
||||
_, err := n.conn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables is not supported: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPayloadDirectives get expression directives based on ip version and direction
|
||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||
switch {
|
||||
|
||||
@@ -14,21 +14,16 @@ import (
|
||||
|
||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
manager, err := newNFTablesManager(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nftables manager: %s", err)
|
||||
}
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
||||
@@ -134,21 +129,16 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
for _, testCase := range insertRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
manager, err := newNFTablesManager(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nftables manager: %s", err)
|
||||
}
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
||||
@@ -239,21 +229,16 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
for _, testCase := range removeRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
manager, err := newNFTablesManager(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nftables manager: %s", err)
|
||||
}
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
table := manager.tableIPv4
|
||||
|
||||
8
client/internal/templates/embed.go
Normal file
8
client/internal/templates/embed.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package templates
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed pkce-auth-msg.html
|
||||
var PKCEAuthMsgTmpl string
|
||||
87
client/internal/templates/pkce-auth-msg.html
Normal file
87
client/internal/templates/pkce-auth-msg.html
Normal file
@@ -0,0 +1,87 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<style>
|
||||
body {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
margin: 0;
|
||||
background: #f7f8f9;
|
||||
font-family: sans-serif, Arial, Tahoma;
|
||||
}
|
||||
|
||||
.container {
|
||||
width: 100%;
|
||||
background: white;
|
||||
border: 1px solid #e8e9ea;
|
||||
text-align: center;
|
||||
padding: 20px;
|
||||
padding-bottom: 50px;
|
||||
max-width: 550px;
|
||||
margin: 0 10px;
|
||||
}
|
||||
|
||||
.logo {
|
||||
height: 80px;
|
||||
border-bottom: 1px solid #e8e9ea;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.logo img {
|
||||
width: 130px;
|
||||
}
|
||||
|
||||
.content {
|
||||
font-size: 13px;
|
||||
color: #525252;
|
||||
line-height: 18px;
|
||||
padding: 10px 0;
|
||||
}
|
||||
|
||||
.content div {
|
||||
font-size: 18px;
|
||||
line-height: normal;
|
||||
margin-bottom: 5px;
|
||||
color: black;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="logo">
|
||||
<img src="https://img.mailinblue.com/6211297/images/content_library/original/64bd4ce82e1ea753e439b6a2.png">
|
||||
</div>
|
||||
<br>
|
||||
{{ if .Error }}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
|
||||
<circle cx="50" cy="50" r="45" fill="none" stroke="red" stroke-width="3"/>
|
||||
<path d="M30 30 L70 70 M30 70 L70 30" fill="none" stroke="red" stroke-width="3"/>
|
||||
</svg>
|
||||
<div class="content">
|
||||
<div>
|
||||
Login failed
|
||||
</div>
|
||||
{{ .Error }}.
|
||||
</div>
|
||||
{{ else }}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
|
||||
<circle cx="50" cy="50" r="45" fill="none" stroke="#5cb85c" stroke-width="3"/>
|
||||
<path d="M30 50 L45 65 L70 35" fill="none" stroke="#5cb85c" stroke-width="5"/>
|
||||
</svg>
|
||||
<div class="content">
|
||||
<div>
|
||||
Login successful
|
||||
</div>
|
||||
Your device is now registered and logged in to NetBird.
|
||||
<br>
|
||||
You can now close this window.
|
||||
</div>
|
||||
{{ end }}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
120
client/internal/wgproxy/ebpf/bpf_bpfeb.go
Normal file
120
client/internal/wgproxy/ebpf/bpf_bpfeb.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
|
||||
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
XdpProgFunc *ebpf.ProgramSpec `ebpf:"xdp_prog_func"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
XdpPortMap *ebpf.MapSpec `ebpf:"xdp_port_map"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
XdpPortMap *ebpf.Map `ebpf:"xdp_port_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.XdpPortMap,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
XdpProgFunc *ebpf.Program `ebpf:"xdp_prog_func"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.XdpProgFunc,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfeb.o
|
||||
var _BpfBytes []byte
|
||||
BIN
client/internal/wgproxy/ebpf/bpf_bpfeb.o
Normal file
BIN
client/internal/wgproxy/ebpf/bpf_bpfeb.o
Normal file
Binary file not shown.
120
client/internal/wgproxy/ebpf/bpf_bpfel.go
Normal file
120
client/internal/wgproxy/ebpf/bpf_bpfel.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
|
||||
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
XdpProgFunc *ebpf.ProgramSpec `ebpf:"xdp_prog_func"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
XdpPortMap *ebpf.MapSpec `ebpf:"xdp_port_map"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
XdpPortMap *ebpf.Map `ebpf:"xdp_port_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.XdpPortMap,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
XdpProgFunc *ebpf.Program `ebpf:"xdp_prog_func"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.XdpProgFunc,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfel.o
|
||||
var _BpfBytes []byte
|
||||
BIN
client/internal/wgproxy/ebpf/bpf_bpfel.o
Normal file
BIN
client/internal/wgproxy/ebpf/bpf_bpfel.o
Normal file
Binary file not shown.
84
client/internal/wgproxy/ebpf/loader.go
Normal file
84
client/internal/wgproxy/ebpf/loader.go
Normal file
@@ -0,0 +1,84 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net"
|
||||
|
||||
"github.com/cilium/ebpf/link"
|
||||
"github.com/cilium/ebpf/rlimit"
|
||||
)
|
||||
|
||||
const (
|
||||
mapKeyProxyPort uint32 = 0
|
||||
mapKeyWgPort uint32 = 1
|
||||
)
|
||||
|
||||
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/portreplace.c --
|
||||
|
||||
// EBPF is a wrapper for eBPF program
|
||||
type EBPF struct {
|
||||
link link.Link
|
||||
}
|
||||
|
||||
// NewEBPF create new EBPF instance
|
||||
func NewEBPF() *EBPF {
|
||||
return &EBPF{}
|
||||
}
|
||||
|
||||
// Load load ebpf program
|
||||
func (l *EBPF) Load(proxyPort, wgPort int) error {
|
||||
// it required for Docker
|
||||
err := rlimit.RemoveMemlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ifce, err := net.InterfaceByName("lo")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load pre-compiled programs into the kernel.
|
||||
objs := bpfObjects{}
|
||||
err = loadBpfObjects(&objs, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = objs.Close()
|
||||
}()
|
||||
|
||||
err = objs.XdpPortMap.Put(mapKeyProxyPort, uint16(proxyPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = objs.XdpPortMap.Put(mapKeyWgPort, uint16(wgPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = objs.XdpPortMap.Close()
|
||||
}()
|
||||
|
||||
l.link, err = link.AttachXDP(link.XDPOptions{
|
||||
Program: objs.XdpProgFunc,
|
||||
Interface: ifce.Index,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Free free ebpf program
|
||||
func (l *EBPF) Free() error {
|
||||
if l.link != nil {
|
||||
return l.link.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
18
client/internal/wgproxy/ebpf/loader_test.go
Normal file
18
client/internal/wgproxy/ebpf/loader_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build linux
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_newEBPF(t *testing.T) {
|
||||
ebpf := NewEBPF()
|
||||
err := ebpf.Load(1234, 51892)
|
||||
defer func() {
|
||||
_ = ebpf.Free()
|
||||
}()
|
||||
if err != nil {
|
||||
t.Errorf("%s", err)
|
||||
}
|
||||
}
|
||||
90
client/internal/wgproxy/ebpf/src/portreplace.c
Normal file
90
client/internal/wgproxy/ebpf/src/portreplace.c
Normal file
@@ -0,0 +1,90 @@
|
||||
#include <stdbool.h>
|
||||
#include <linux/if_ether.h> // ETH_P_IP
|
||||
#include <linux/udp.h>
|
||||
#include <linux/ip.h>
|
||||
#include <netinet/in.h>
|
||||
#include <linux/bpf.h>
|
||||
#include <bpf/bpf_helpers.h>
|
||||
|
||||
#define bpf_printk(fmt, ...) \
|
||||
({ \
|
||||
char ____fmt[] = fmt; \
|
||||
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
|
||||
})
|
||||
|
||||
const __u32 map_key_proxy_port = 0;
|
||||
const __u32 map_key_wg_port = 1;
|
||||
|
||||
struct bpf_map_def SEC("maps") xdp_port_map = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u16),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
__u16 proxy_port = 0;
|
||||
__u16 wg_port = 0;
|
||||
|
||||
bool read_port_settings() {
|
||||
__u16 *value;
|
||||
value = bpf_map_lookup_elem(&xdp_port_map, &map_key_proxy_port);
|
||||
if(!value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
proxy_port = *value;
|
||||
|
||||
value = bpf_map_lookup_elem(&xdp_port_map, &map_key_wg_port);
|
||||
if(!value) {
|
||||
return false;
|
||||
}
|
||||
wg_port = *value;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
SEC("xdp")
|
||||
int xdp_prog_func(struct xdp_md *ctx) {
|
||||
if(proxy_port == 0 || wg_port == 0) {
|
||||
if(!read_port_settings()){
|
||||
return XDP_PASS;
|
||||
}
|
||||
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
|
||||
}
|
||||
|
||||
void *data = (void *)(long)ctx->data;
|
||||
void *data_end = (void *)(long)ctx->data_end;
|
||||
struct ethhdr *eth = data;
|
||||
struct iphdr *ip = (data + sizeof(struct ethhdr));
|
||||
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
|
||||
|
||||
// return early if not enough data
|
||||
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
// skip non IPv4 packages
|
||||
if (eth->h_proto != htons(ETH_P_IP)) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (ip->protocol != IPPROTO_UDP) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
// 2130706433 = 127.0.0.1
|
||||
if (ip->daddr != htonl(2130706433)) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (udp->source != htons(wg_port)){
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
__be16 new_src_port = udp->dest;
|
||||
__be16 new_dst_port = htons(proxy_port);
|
||||
udp->dest = new_dst_port;
|
||||
udp->source = new_src_port;
|
||||
return XDP_PASS;
|
||||
}
|
||||
char _license[] SEC("license") = "GPL";
|
||||
20
client/internal/wgproxy/factory.go
Normal file
20
client/internal/wgproxy/factory.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package wgproxy
|
||||
|
||||
type Factory struct {
|
||||
wgPort int
|
||||
ebpfProxy Proxy
|
||||
}
|
||||
|
||||
func (w *Factory) GetProxy() Proxy {
|
||||
if w.ebpfProxy != nil {
|
||||
return w.ebpfProxy
|
||||
}
|
||||
return NewWGUserSpaceProxy(w.wgPort)
|
||||
}
|
||||
|
||||
func (w *Factory) Free() error {
|
||||
if w.ebpfProxy != nil {
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
21
client/internal/wgproxy/factory_linux.go
Normal file
21
client/internal/wgproxy/factory_linux.go
Normal file
@@ -0,0 +1,21 @@
|
||||
//go:build !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewFactory(wgPort int) *Factory {
|
||||
f := &Factory{wgPort: wgPort}
|
||||
|
||||
ebpfProxy := NewWGEBPFProxy(wgPort)
|
||||
err := ebpfProxy.Listen()
|
||||
if err != nil {
|
||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||
return f
|
||||
}
|
||||
|
||||
f.ebpfProxy = ebpfProxy
|
||||
return f
|
||||
}
|
||||
7
client/internal/wgproxy/factory_nonlinux.go
Normal file
7
client/internal/wgproxy/factory_nonlinux.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package wgproxy
|
||||
|
||||
func NewFactory(wgPort int) *Factory {
|
||||
return &Factory{wgPort: wgPort}
|
||||
}
|
||||
32
client/internal/wgproxy/portlookup.go
Normal file
32
client/internal/wgproxy/portlookup.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
portRangeStart = 3128
|
||||
portRangeEnd = 3228
|
||||
)
|
||||
|
||||
type portLookup struct {
|
||||
}
|
||||
|
||||
func (pl portLookup) searchFreePort() (int, error) {
|
||||
for i := portRangeStart; i <= portRangeEnd; i++ {
|
||||
if pl.tryToBind(i) == nil {
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("failed to bind free port for eBPF proxy")
|
||||
}
|
||||
|
||||
func (pl portLookup) tryToBind(port int) error {
|
||||
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = l.Close()
|
||||
return nil
|
||||
}
|
||||
42
client/internal/wgproxy/portlookup_test.go
Normal file
42
client/internal/wgproxy/portlookup_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_portLookup_searchFreePort(t *testing.T) {
|
||||
pl := portLookup{}
|
||||
_, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_portLookup_on_allocated(t *testing.T) {
|
||||
pl := portLookup{}
|
||||
|
||||
allocatedPort, err := allocatePort(portRangeStart)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer allocatedPort.Close()
|
||||
|
||||
fp, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if fp != (portRangeStart + 1) {
|
||||
t.Errorf("invalid free port, expected: %d, got: %d", portRangeStart+1, fp)
|
||||
}
|
||||
}
|
||||
|
||||
func allocatePort(port int) (net.PacketConn, error) {
|
||||
c, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
12
client/internal/wgproxy/proxy.go
Normal file
12
client/internal/wgproxy/proxy.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// Proxy is a transfer layer between the Turn connection and the WireGuard
|
||||
type Proxy interface {
|
||||
AddTurnConn(urnConn net.Conn) (net.Addr, error)
|
||||
CloseConn() error
|
||||
Free() error
|
||||
}
|
||||
257
client/internal/wgproxy/proxy_ebpf.go
Normal file
257
client/internal/wgproxy/proxy_ebpf.go
Normal file
@@ -0,0 +1,257 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
ebpf2 "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
ebpf *ebpf2.EBPF
|
||||
lastUsedPort uint16
|
||||
localWGListenPort int
|
||||
|
||||
turnConnStore map[uint16]net.Conn
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
rawConn net.PacketConn
|
||||
conn *net.UDPConn
|
||||
}
|
||||
|
||||
// NewWGEBPFProxy create new WGEBPFProxy instance
|
||||
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
||||
log.Debugf("instantiate ebpf proxy")
|
||||
wgProxy := &WGEBPFProxy{
|
||||
localWGListenPort: wgPort,
|
||||
ebpf: ebpf2.NewEBPF(),
|
||||
lastUsedPort: 0,
|
||||
turnConnStore: make(map[uint16]net.Conn),
|
||||
}
|
||||
return wgProxy
|
||||
}
|
||||
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
pl := portLookup{}
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.rawConn, err = p.prepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.ebpf.Load(wgPorxyPort, p.localWGListenPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: wgPorxyPort,
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
|
||||
p.conn, err = net.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
cErr := p.Free()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close the wgproxy: %s", cErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
go p.proxyToRemote()
|
||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTurnConn add new turn connection for the proxy
|
||||
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go p.proxyToLocal(wgEndpointPort, turnConn)
|
||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||
|
||||
wgEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: int(wgEndpointPort),
|
||||
}
|
||||
return wgEndpoint, nil
|
||||
}
|
||||
|
||||
// CloseConn doing nothing because this type of proxy implementation does not store the connection
|
||||
func (p *WGEBPFProxy) CloseConn() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Free resources
|
||||
func (p *WGEBPFProxy) Free() error {
|
||||
log.Debugf("free up ebpf wg proxy")
|
||||
var err1, err2, err3 error
|
||||
if p.conn != nil {
|
||||
err1 = p.conn.Close()
|
||||
}
|
||||
|
||||
err2 = p.ebpf.Free()
|
||||
if p.rawConn != nil {
|
||||
err3 = p.rawConn.Close()
|
||||
}
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
|
||||
return err3
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||
}
|
||||
p.removeTurnConn(endpointPort)
|
||||
return
|
||||
}
|
||||
err = p.sendPkg(buf[:n], endpointPort)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, addr, err := p.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read UDP pkg from WG: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.turnConnMutex.Lock()
|
||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||
size := len(p.turnConnStore)
|
||||
p.turnConnMutex.Unlock()
|
||||
if !ok {
|
||||
log.Errorf("turn conn not found by port: %d", addr.Port)
|
||||
log.Debugf("conn store size: %d", size)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = conn.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
|
||||
np, err := p.nextFreePort()
|
||||
if err != nil {
|
||||
return np, err
|
||||
}
|
||||
p.turnConnStore[np] = turnConn
|
||||
return np, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
|
||||
log.Tracef("remove turn conn from store by port: %d", turnConnID)
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
delete(p.turnConnStore, turnConnID)
|
||||
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
|
||||
if len(p.turnConnStore) == 65535 {
|
||||
return 0, fmt.Errorf("reached maximum turn connection numbers")
|
||||
}
|
||||
generatePort:
|
||||
if p.lastUsedPort == 65535 {
|
||||
p.lastUsedPort = 1
|
||||
} else {
|
||||
p.lastUsedPort++
|
||||
}
|
||||
|
||||
if _, ok := p.turnConnStore[p.lastUsedPort]; ok {
|
||||
goto generatePort
|
||||
}
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return net.FilePacketConn(os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)))
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
||||
localhost := net.ParseIP("127.0.0.1")
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localhost,
|
||||
SrcIP: localhost,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layerBuffer := gopacket.NewSerializeBuffer()
|
||||
|
||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost})
|
||||
return err
|
||||
}
|
||||
56
client/internal/wgproxy/proxy_ebpf_test.go
Normal file
56
client/internal/wgproxy/proxy_ebpf_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWGEBPFProxy_connStore(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
p, _ := wgProxy.storeTurnConn(nil)
|
||||
if p != 1 {
|
||||
t.Errorf("invalid initial port: %d", wgProxy.lastUsedPort)
|
||||
}
|
||||
|
||||
numOfConns := 10
|
||||
for i := 0; i < numOfConns; i++ {
|
||||
p, _ = wgProxy.storeTurnConn(nil)
|
||||
}
|
||||
if p != uint16(numOfConns)+1 {
|
||||
t.Errorf("invalid last used port: %d, expected: %d", p, numOfConns+1)
|
||||
}
|
||||
if len(wgProxy.turnConnStore) != numOfConns+1 {
|
||||
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), numOfConns+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
wgProxy.lastUsedPort = 65535
|
||||
p, _ := wgProxy.storeTurnConn(nil)
|
||||
|
||||
if len(wgProxy.turnConnStore) != 2 {
|
||||
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), 2)
|
||||
}
|
||||
|
||||
if p != 2 {
|
||||
t.Errorf("invalid last used port: %d, expected: %d", p, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
for i := 0; i < 65535; i++ {
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
}
|
||||
|
||||
_, err := wgProxy.storeTurnConn(nil)
|
||||
if err == nil {
|
||||
t.Errorf("invalid turn conn store calculation")
|
||||
}
|
||||
}
|
||||
107
client/internal/wgproxy/proxy_userspace.go
Normal file
107
client/internal/wgproxy/proxy_userspace.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// WGUserSpaceProxy proxies
|
||||
type WGUserSpaceProxy struct {
|
||||
localWGListenPort int
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
}
|
||||
|
||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
|
||||
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||
log.Debugf("instantiate new userspace proxy")
|
||||
p := &WGUserSpaceProxy{
|
||||
localWGListenPort: wgPort,
|
||||
}
|
||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn start the proxy with the given remote conn
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("add turn conn: %s", remoteConn.RemoteAddr())
|
||||
|
||||
go p.proxyToRemote()
|
||||
go p.proxyToLocal()
|
||||
|
||||
return p.localConn.LocalAddr(), err
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
||||
p.cancel()
|
||||
if p.localConn == nil {
|
||||
return nil
|
||||
}
|
||||
return p.localConn.Close()
|
||||
}
|
||||
|
||||
// Free doing nothing because this implementation of proxy does not have global state
|
||||
func (p *WGUserSpaceProxy) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||
// blocks
|
||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||
// blocks
|
||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -38,8 +39,8 @@ type Server struct {
|
||||
|
||||
type oauthAuthFlow struct {
|
||||
expiresAt time.Time
|
||||
client internal.OAuthClient
|
||||
info internal.DeviceAuthInfo
|
||||
flow auth.OAuthFlow
|
||||
info auth.AuthFlowInfo
|
||||
waitCancel context.CancelFunc
|
||||
}
|
||||
|
||||
@@ -102,7 +103,7 @@ func (s *Server) Start() error {
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil {
|
||||
if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
|
||||
log.Errorf("init connections: %v", err)
|
||||
}
|
||||
}()
|
||||
@@ -206,28 +207,15 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
if msg.SetupKey == "" {
|
||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
|
||||
if err != nil {
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
return nil, gstatus.Errorf(codes.NotFound, "no SSO provider returned from management. "+
|
||||
"If you are using hosting Netbird see documentation at "+
|
||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
return nil, gstatus.Errorf(codes.Unimplemented, "the management server, %s, does not support SSO providers, "+
|
||||
"please update your server or use Setup Keys to login", config.ManagementURL)
|
||||
} else {
|
||||
log.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
|
||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||
log.Debugf("using previous device flow info")
|
||||
log.Debugf("using previous oauth flow info")
|
||||
return &proto.LoginResponse{
|
||||
NeedsSSOLogin: true,
|
||||
VerificationURI: s.oauthAuthFlow.info.VerificationURI,
|
||||
@@ -242,25 +230,25 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
}
|
||||
|
||||
deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
if err != nil {
|
||||
log.Errorf("getting a request device code failed: %v", err)
|
||||
log.Errorf("getting a request OAuth flow failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow.client = hostedClient
|
||||
s.oauthAuthFlow.info = deviceAuthInfo
|
||||
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(deviceAuthInfo.ExpiresIn) * time.Second)
|
||||
s.oauthAuthFlow.flow = oAuthFlow
|
||||
s.oauthAuthFlow.info = authInfo
|
||||
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
|
||||
s.mutex.Unlock()
|
||||
|
||||
state.Set(internal.StatusNeedsLogin)
|
||||
|
||||
return &proto.LoginResponse{
|
||||
NeedsSSOLogin: true,
|
||||
VerificationURI: deviceAuthInfo.VerificationURI,
|
||||
VerificationURIComplete: deviceAuthInfo.VerificationURIComplete,
|
||||
UserCode: deviceAuthInfo.UserCode,
|
||||
VerificationURI: authInfo.VerificationURI,
|
||||
VerificationURIComplete: authInfo.VerificationURIComplete,
|
||||
UserCode: authInfo.UserCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -289,8 +277,8 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
s.actCancel = cancel
|
||||
s.mutex.Unlock()
|
||||
|
||||
if s.oauthAuthFlow.client == nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized")
|
||||
if s.oauthAuthFlow.flow == nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "oauth flow is not initialized")
|
||||
}
|
||||
|
||||
state := internal.CtxGetState(ctx)
|
||||
@@ -304,10 +292,10 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
s.mutex.Lock()
|
||||
deviceAuthInfo := s.oauthAuthFlow.info
|
||||
flowInfo := s.oauthAuthFlow.info
|
||||
s.mutex.Unlock()
|
||||
|
||||
if deviceAuthInfo.UserCode != msg.UserCode {
|
||||
if flowInfo.UserCode != msg.UserCode {
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid")
|
||||
}
|
||||
@@ -324,7 +312,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
s.oauthAuthFlow.waitCancel = cancel
|
||||
s.mutex.Unlock()
|
||||
|
||||
tokenInfo, err := s.oauthAuthFlow.client.WaitToken(waitCTX, deviceAuthInfo)
|
||||
tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
if err == context.Canceled {
|
||||
return nil, nil
|
||||
@@ -391,7 +379,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil {
|
||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
|
||||
log.Errorf("run client connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,9 +2,6 @@ package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
@@ -13,11 +10,22 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
const DefaultSSHPort = 44338
|
||||
|
||||
// TerminalTimeout is the timeout for terminal session to be ready
|
||||
const TerminalTimeout = 10 * time.Second
|
||||
|
||||
// TerminalBackoffDelay is the delay between terminal session readiness checks
|
||||
const TerminalBackoffDelay = 500 * time.Millisecond
|
||||
|
||||
// DefaultSSHServer is a function that creates DefaultServer
|
||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||
return newDefaultServer(hostKeyPEM, addr)
|
||||
@@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
|
||||
|
||||
localUser, err := userNameLookup(session.User())
|
||||
if err != nil {
|
||||
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
||||
@@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Login command: %s", cmd.String())
|
||||
file, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
log.Errorf("failed starting SSH server %v", err)
|
||||
@@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Debugf("SSH session ended")
|
||||
}
|
||||
|
||||
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||
@@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||
// stdin
|
||||
_, err := io.Copy(file, session)
|
||||
if err != nil {
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// stdout
|
||||
_, err := io.Copy(session, file)
|
||||
if err != nil {
|
||||
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
|
||||
timer := time.NewTimer(TerminalTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
default:
|
||||
// stdout
|
||||
writtenBytes, err := io.Copy(session, file)
|
||||
if err != nil && writtenBytes != 0 {
|
||||
_ = session.Exit(0)
|
||||
return
|
||||
}
|
||||
time.Sleep(TerminalBackoffDelay)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts SSH server. Blocking
|
||||
|
||||
3
go.mod
3
go.mod
@@ -30,6 +30,7 @@ require (
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.1.4
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/cilium/ebpf v0.10.0
|
||||
github.com/coreos/go-iptables v0.6.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/eko/gocache/v3 v3.1.1
|
||||
@@ -48,6 +49,7 @@ require (
|
||||
github.com/mdlayher/socket v0.4.0
|
||||
github.com/miekg/dns v1.1.43
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pion/logging v0.2.2
|
||||
@@ -124,7 +126,6 @@ require (
|
||||
github.com/prometheus/client_model v0.3.0 // indirect
|
||||
github.com/prometheus/common v0.37.0 // indirect
|
||||
github.com/prometheus/procfs v0.8.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.8.0 // indirect
|
||||
github.com/spf13/cast v1.5.0 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
|
||||
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
|
||||
|
||||
12
go.sum
12
go.sum
@@ -101,6 +101,8 @@ github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3/go.mod h1:MA5e5Lr8slmE
|
||||
github.com/cilium/ebpf v0.4.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs=
|
||||
github.com/cilium/ebpf v0.5.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs=
|
||||
github.com/cilium/ebpf v0.7.0/go.mod h1:/oI2+1shJiTGAMgl6/RgJr36Eo1jzrRcAWbcXO2usCA=
|
||||
github.com/cilium/ebpf v0.10.0 h1:nk5HPMeoBXtOzbkZBWym+ZWq1GIiHUsBFXxwewXAHLQ=
|
||||
github.com/cilium/ebpf v0.10.0/go.mod h1:DPiVdY/kT534dgc9ERmvP8mWA+9gvwgKfRvk4nNWnoE=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
@@ -177,7 +179,7 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL
|
||||
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
|
||||
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
|
||||
github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k=
|
||||
github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE=
|
||||
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
|
||||
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 h1:FDqhDm7pcsLhhWl1QtD8vlzI4mm59llRvNzrFg6/LAA=
|
||||
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3/go.mod h1:CzM2G82Q9BDUvMTGHnXf/6OExw/Dz2ivDj48nVg7Lg8=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
@@ -419,7 +421,7 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
@@ -485,6 +487,8 @@ github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8m
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
|
||||
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
|
||||
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw=
|
||||
@@ -550,7 +554,6 @@ github.com/pion/turn/v2 v2.1.0 h1:5wGHSgGhJhP/RpabkUb/T9PdsAjkGLS6toYz5HNzoSI=
|
||||
github.com/pion/turn/v2 v2.1.0/go.mod h1:yrT5XbXSGX1VFSF31A3c1kCNB5bBZgk/uu5LET162qs=
|
||||
github.com/pion/udp/v2 v2.0.1 h1:xP0z6WNux1zWEjhC7onRA3EwwSliXqu1ElUZAQhUP54=
|
||||
github.com/pion/udp/v2 v2.0.1/go.mod h1:B7uvTMP00lzWdyMr/1PVZXtV3wpPIxBRd4Wl6AksXn8=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
@@ -594,8 +597,7 @@ github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40T
|
||||
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
|
||||
github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM=
|
||||
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
|
||||
|
||||
@@ -82,7 +82,7 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
||||
default:
|
||||
_, a, err := m.params.UDPConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
log.Errorf("error while reading packet %s", err)
|
||||
log.Errorf("error while reading packet: %s", err)
|
||||
continue
|
||||
}
|
||||
msg := &stun.Message{
|
||||
|
||||
@@ -74,7 +74,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
|
||||
log.Debugf("updating interface %s peer %s, endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
|
||||
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
@@ -34,7 +33,6 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter
|
||||
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
||||
return w.tun.Create(mIFaceArgs)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
@@ -38,6 +37,5 @@ func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
||||
func (w *WGIface) Create() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
||||
return w.tun.Create()
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
|
||||
}
|
||||
|
||||
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
|
||||
log.Info("create tun interface")
|
||||
var err error
|
||||
routesString := t.routesToString(mIFaceArgs.Routes)
|
||||
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
|
||||
func (c *tunDevice) Create() error {
|
||||
if WireGuardModuleIsLoaded() {
|
||||
log.Info("using kernel WireGuard")
|
||||
log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName())
|
||||
return c.createWithKernel()
|
||||
}
|
||||
|
||||
if !tunModuleIsLoaded() {
|
||||
return fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
log.Info("using userspace WireGuard")
|
||||
log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName())
|
||||
var err error
|
||||
c.netInterface, err = c.createWithUserspace()
|
||||
if err != nil {
|
||||
|
||||
@@ -56,7 +56,7 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
|
||||
c.wrapper = newDeviceWrapper(tunIface)
|
||||
|
||||
// We need to create a wireguard-go device and listen to configuration requests
|
||||
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
||||
tunDev := device.NewDevice(c.wrapper, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
||||
err = tunDev.Up()
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
|
||||
@@ -43,6 +43,10 @@ NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN=${NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN:-f
|
||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=${NETBIRD_DISABLE_ANONYMOUS_METRICS:-false}
|
||||
NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
|
||||
|
||||
# PKCE authorization flow
|
||||
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
|
||||
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
|
||||
|
||||
# exports
|
||||
export NETBIRD_DOMAIN
|
||||
export NETBIRD_AUTH_CLIENT_ID
|
||||
@@ -80,4 +84,6 @@ export NETBIRD_AUTH_USER_ID_CLAIM
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
export NETBIRD_TOKEN_SOURCE
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_SCOPE
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
|
||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user