mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 17:04:21 -04:00
Compare commits
49 Commits
v0.67.4
...
client-ipv
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a457faa876 | ||
|
|
95f4e90ae8 | ||
|
|
21ae32bb52 | ||
|
|
310b7dd89b | ||
|
|
44d16e8791 | ||
|
|
a9ee25aeef | ||
|
|
ed5cfa6dc5 | ||
|
|
fcf8c4b30e | ||
|
|
faef5a57b3 | ||
|
|
22c4be0c34 | ||
|
|
b6bd2d667b | ||
|
|
571527c2d3 | ||
|
|
c13f1af196 | ||
|
|
e4857b4d9d | ||
|
|
da6f61039a | ||
|
|
76414a1061 | ||
|
|
1cc19e7355 | ||
|
|
1a6cf9dfec | ||
|
|
6acc6a13f1 | ||
|
|
58eb519dbc | ||
|
|
aebf3ceb40 | ||
|
|
bc6ed1a97f | ||
|
|
50c0bc583b | ||
|
|
5e1cdd7d36 | ||
|
|
7fe417c6b4 | ||
|
|
0c2fbd5d70 | ||
|
|
641e3861c1 | ||
|
|
baf2c03508 | ||
|
|
569244a59c | ||
|
|
5a29fa8432 | ||
|
|
5fcea07181 | ||
|
|
3be5a5f230 | ||
|
|
780fd66dd7 | ||
|
|
133f004086 | ||
|
|
d81cd5d154 | ||
|
|
71962f88f8 | ||
|
|
878dc45abf | ||
|
|
1a7e835949 | ||
|
|
b852ce1a99 | ||
|
|
013770070a | ||
|
|
acdf8d981a | ||
|
|
e2f774824b | ||
|
|
3963072c43 | ||
|
|
8550765f38 | ||
|
|
67fb6be40a | ||
|
|
cd7290a497 | ||
|
|
63c19dbf2e | ||
|
|
01c4d5761d | ||
|
|
e916e0d7fa |
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -61,8 +61,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 57671680 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -203,10 +203,11 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
|
||||
for n, p := range fullStatus.Peers {
|
||||
pi := PeerInfo{
|
||||
p.IP,
|
||||
p.FQDN,
|
||||
p.ConnStatus.String(),
|
||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||
IP: p.IP,
|
||||
IPv6: p.IPv6,
|
||||
FQDN: p.FQDN,
|
||||
ConnStatus: p.ConnStatus.String(),
|
||||
Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||
}
|
||||
peerInfos[n] = pi
|
||||
}
|
||||
@@ -237,43 +238,84 @@ func (c *Client) Networks() *NetworkArray {
|
||||
return nil
|
||||
}
|
||||
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
v6Merged := route.V6ExitMergeSet(routesMap)
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
networkArray := &NetworkArray{
|
||||
items: make([]Network, 0),
|
||||
}
|
||||
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||
for id, routes := range routesMap {
|
||||
if len(routes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r := routes[0]
|
||||
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||
netStr := r.Network.String()
|
||||
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
if _, skip := v6Merged[id]; skip {
|
||||
continue
|
||||
}
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: routePeer.FQDN,
|
||||
Status: routePeer.ConnStatus.String(),
|
||||
IsSelected: routeSelector.IsSelected(id),
|
||||
Domains: domains,
|
||||
|
||||
network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged)
|
||||
if network == nil {
|
||||
continue
|
||||
}
|
||||
networkArray.Add(network)
|
||||
networkArray.Add(*network)
|
||||
}
|
||||
return networkArray
|
||||
}
|
||||
|
||||
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
|
||||
r := routes[0]
|
||||
netStr := r.Network.String()
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
routePeer, err := c.findBestRoutePeer(routes)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for route %s: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
network := &Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: routePeer.FQDN,
|
||||
Status: routePeer.ConnStatus.String(),
|
||||
IsSelected: selected,
|
||||
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
|
||||
}
|
||||
|
||||
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
|
||||
network.Network = "0.0.0.0/0, ::/0"
|
||||
}
|
||||
|
||||
return network
|
||||
}
|
||||
|
||||
// findBestRoutePeer returns the peer actively routing traffic for the given
|
||||
// HA route group. Falls back to the first connected peer, then the first peer.
|
||||
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
|
||||
netStr := routes[0].Network.String()
|
||||
|
||||
fullStatus := c.recorder.GetFullStatus()
|
||||
for _, p := range fullStatus.Peers {
|
||||
if _, ok := p.GetRoutes()[netStr]; ok {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
p, err := c.recorder.GetPeer(r.Peer)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if p.ConnStatus == peer.StatusConnected {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return c.recorder.GetPeer(routes[0].Peer)
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
dnsServer, err := dns.GetServerDns()
|
||||
|
||||
@@ -5,6 +5,7 @@ package android
|
||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||
type PeerInfo struct {
|
||||
IP string
|
||||
IPv6 string
|
||||
FQDN string
|
||||
ConnStatus string // Todo replace to enum
|
||||
Routes PeerRoutes
|
||||
|
||||
@@ -307,6 +307,24 @@ func (p *Preferences) SetBlockInbound(block bool) {
|
||||
p.configInput.BlockInbound = &block
|
||||
}
|
||||
|
||||
// GetDisableIPv6 reads disable IPv6 setting from config file
|
||||
func (p *Preferences) GetDisableIPv6() (bool, error) {
|
||||
if p.configInput.DisableIPv6 != nil {
|
||||
return *p.configInput.DisableIPv6, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableIPv6, err
|
||||
}
|
||||
|
||||
// SetDisableIPv6 stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableIPv6(disable bool) {
|
||||
p.configInput.DisableIPv6 = &disable
|
||||
}
|
||||
|
||||
// Commit writes out the changes to the config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
|
||||
|
||||
@@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager,
|
||||
netID := route.NetID(id)
|
||||
routes := []route.NetID{netID}
|
||||
|
||||
log.Debugf("%s with id: %s", operationName, id)
|
||||
routesMap := manager.GetClientRoutesWithNetID()
|
||||
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||
|
||||
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||
log.Debugf("%s with ids: %v", operationName, routes)
|
||||
|
||||
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
|
||||
log.Debugf("error when %s: %s", operationName, err)
|
||||
return fmt.Errorf("error %s: %w", operationName, err)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -26,8 +27,9 @@ type Anonymizer struct {
|
||||
}
|
||||
|
||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||
// 198.51.100.0, 100::
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||
// 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48)
|
||||
// The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android.
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
|
||||
}
|
||||
|
||||
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
||||
@@ -96,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
||||
// Handle CIDR notation (e.g. "2001:db8::/32")
|
||||
if prefix, err := netip.ParsePrefix(ip); err == nil {
|
||||
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return ip
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
func TestAnonymizeIP(t *testing.T) {
|
||||
startIPv4 := netip.MustParseAddr("198.51.100.0")
|
||||
startIPv6 := netip.MustParseAddr("100::")
|
||||
startIPv6 := netip.MustParseAddr("2001:db8:ffff::")
|
||||
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
|
||||
|
||||
tests := []struct {
|
||||
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
|
||||
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
|
||||
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
|
||||
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
|
||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"Second Public IPv6", "a::b", "100::1"},
|
||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
||||
{"Second Public IPv6", "a::b", "2001:db8:ffff::1"},
|
||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
||||
{"Private IPv6", "fe80::1", "fe80::1"},
|
||||
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
|
||||
}
|
||||
@@ -274,17 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
|
||||
{
|
||||
name: "IPv6 Address",
|
||||
input: "Access attempted from 2001:db8::ff00:42",
|
||||
expect: "Access attempted from 100::",
|
||||
expect: "Access attempted from 2001:db8:ffff::",
|
||||
},
|
||||
{
|
||||
name: "IPv6 Address with Port",
|
||||
input: "Access attempted from [2001:db8::ff00:42]:8080",
|
||||
expect: "Access attempted from [100::]:8080",
|
||||
expect: "Access attempted from [2001:db8:ffff::]:8080",
|
||||
},
|
||||
{
|
||||
name: "Both IPv4 and IPv6",
|
||||
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
||||
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
|
||||
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
|
||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||
}
|
||||
|
||||
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||
// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack).
|
||||
func normalizeLocalHost(host string) string {
|
||||
if host == "*" {
|
||||
return "0.0.0.0"
|
||||
return ""
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
|
||||
{
|
||||
name: "wildcard bind all interfaces",
|
||||
spec: "*:8080:localhost:80",
|
||||
expectedLocal: "0.0.0.0:8080",
|
||||
expectedLocal: ":8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||
description: "Wildcard * should bind to all interfaces (dual-stack)",
|
||||
},
|
||||
{
|
||||
name: "wildcard for port only",
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
ipv6Flag bool
|
||||
jsonFlag bool
|
||||
yamlFlag bool
|
||||
ipsFilter []string
|
||||
@@ -45,8 +46,9 @@ func init() {
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
@@ -101,6 +103,14 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ipv6Flag {
|
||||
ipv6 := resp.GetFullStatus().GetLocalPeerState().GetIpv6()
|
||||
if ipv6 != "" {
|
||||
cmd.Print(parseInterfaceIP(ipv6))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
|
||||
@@ -8,6 +8,7 @@ const (
|
||||
disableFirewallFlag = "disable-firewall"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
blockInboundFlag = "block-inbound"
|
||||
disableIPv6Flag = "disable-ipv6"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -17,6 +18,7 @@ var (
|
||||
disableFirewall bool
|
||||
blockLANAccess bool
|
||||
blockInbound bool
|
||||
disableIPv6 bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -39,4 +41,7 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableIPv6, disableIPv6Flag, false,
|
||||
"Disable IPv6 overlay. If enabled, the client won't request or use an IPv6 overlay address.")
|
||||
}
|
||||
|
||||
@@ -430,6 +430,10 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
req.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(disableIPv6Flag).Changed {
|
||||
req.DisableIpv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
req.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
@@ -547,6 +551,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(disableIPv6Flag).Changed {
|
||||
ic.DisableIPv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
@@ -661,6 +669,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(disableIPv6Flag).Changed {
|
||||
loginRequest.DisableIpv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
@@ -79,6 +79,8 @@ type Options struct {
|
||||
StatePath string
|
||||
// DisableClientRoutes disables the client routes
|
||||
DisableClientRoutes bool
|
||||
// DisableIPv6 disables IPv6 overlay addressing
|
||||
DisableIPv6 bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||
@@ -170,6 +172,7 @@ func New(opts Options) (*Client, error) {
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
DisableIPv6: &opts.DisableIPv6,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
MTU: opts.MTU,
|
||||
|
||||
@@ -36,6 +36,7 @@ type aclManager struct {
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
v6 bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
@@ -47,6 +48,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -81,7 +83,11 @@ func (m *aclManager) AddPeerFiltering(
|
||||
chain := chainNameInputRules
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
||||
if m.v6 && ipsetName != "" {
|
||||
ipsetName += "-v6"
|
||||
}
|
||||
proto := protoForFamily(protocol, m.v6)
|
||||
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
||||
|
||||
mangleSpecs := slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
@@ -105,6 +111,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
specs: specs,
|
||||
v6: m.v6,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
@@ -157,6 +164,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
v6: m.v6,
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
@@ -376,8 +384,13 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -385,6 +398,15 @@ func (m *aclManager) updateState() {
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||
if v6 && protocol == firewall.ProtocolICMP {
|
||||
return "ipv6-icmp"
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
matchByIP := !ip.IsUnspecified()
|
||||
@@ -437,6 +459,9 @@ func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if m.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
|
||||
@@ -27,6 +27,11 @@ type Manager struct {
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
rawSupported bool
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
ipv6Client *iptables.IPTables
|
||||
aclMgr6 *aclManager
|
||||
router6 *router
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -58,9 +63,39 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
if err := m.createIPv6Components(wgIface, mtu); err != nil {
|
||||
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
||||
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init ip6tables: %w", err)
|
||||
}
|
||||
m.ipv6Client = ip6Client
|
||||
|
||||
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
}
|
||||
|
||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) hasIPv6() bool {
|
||||
return m.ipv6Client != nil
|
||||
}
|
||||
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
state := &ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
@@ -84,6 +119,15 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.init(stateManager); err != nil {
|
||||
return fmt.Errorf("v6 router init: %w", err)
|
||||
}
|
||||
if err := m.aclMgr6.init(stateManager); err != nil {
|
||||
return fmt.Errorf("v6 acl manager init: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
@@ -113,7 +157,13 @@ func (m *Manager) AddPeerFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
if ip.To4() != nil {
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
|
||||
}
|
||||
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
@@ -127,25 +177,48 @@ func (m *Manager) AddRouteFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6IptRule(rule) {
|
||||
return m.aclMgr6.DeletePeerRule(rule)
|
||||
}
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6IptRule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.v6
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule.
|
||||
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||
return m.router6.DeleteRouteRule(rule)
|
||||
}
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
@@ -161,18 +234,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Dynamic routes need NAT in both tables
|
||||
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||
v6Pair := pair
|
||||
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveNatRule(pair)
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||
v6Pair := pair
|
||||
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("remove v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -186,6 +304,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclMgr6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
||||
}
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
@@ -209,19 +336,16 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
var merr *multierror.Error
|
||||
if _, err := m.aclMgr.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird interface traffic: %w", err))
|
||||
}
|
||||
return nil
|
||||
if m.hasIPv6() {
|
||||
if _, err := m.aclMgr6.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow v6 netbird interface traffic: %w", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
@@ -251,6 +375,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
|
||||
return m.router6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -259,6 +386,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
||||
return m.router6.DeleteDNATRule(rule)
|
||||
}
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -267,7 +397,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
var v4Prefixes, v6Prefixes []netip.Prefix
|
||||
for _, p := range prefixes {
|
||||
if p.Addr().Is6() {
|
||||
v6Prefixes = append(v6Prefixes, p)
|
||||
} else {
|
||||
v4Prefixes = append(v4Prefixes, p)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
@@ -275,6 +424,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && localAddr.Is6() {
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
@@ -283,6 +435,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && localAddr.Is6() {
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
|
||||
@@ -52,8 +52,10 @@ const (
|
||||
snatSuffix = "_snat"
|
||||
fwdSuffix = "_fwd"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
@@ -84,6 +86,7 @@ type router struct {
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
v6 bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
@@ -95,6 +98,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
mtu: mtu,
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
@@ -184,6 +188,11 @@ func (r *router) AddRouteFiltering(
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) hasRule(id string) bool {
|
||||
_, ok := r.rules[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.ID()
|
||||
|
||||
@@ -423,6 +432,12 @@ func (r *router) createContainers() error {
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
// Fallback: clear chains that survived an unclean shutdown.
|
||||
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
@@ -529,9 +544,12 @@ func (r *router) addPostroutingRules() error {
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.v6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||
jumpRule := []string{
|
||||
@@ -716,8 +734,13 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -845,7 +868,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+fwdSuffix)
|
||||
@@ -872,7 +895,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
|
||||
rule = append(rule, destExp...)
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
|
||||
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||
}
|
||||
@@ -889,11 +912,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
||||
}
|
||||
|
||||
if network.IsSet() {
|
||||
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||
name := r.ipsetName(network.Set.HashedName())
|
||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||
return []string{"-m", "set", matchSet, name, direction}, nil
|
||||
}
|
||||
if network.IsPrefix() {
|
||||
return []string{flag, network.Prefix.String()}, nil
|
||||
@@ -904,19 +928,15 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
||||
}
|
||||
|
||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
name := r.ipsetName(set.HashedName())
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
@@ -932,7 +952,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
@@ -990,10 +1010,22 @@ func applyPort(flag string, port *firewall.Port) []string {
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
|
||||
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
|
||||
// to avoid collisions since ipsets are global in the kernel.
|
||||
func (r *router) ipsetName(name string) string {
|
||||
if r.v6 {
|
||||
return name + "-v6"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func (r *router) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if r.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
|
||||
@@ -9,6 +9,7 @@ type Rule struct {
|
||||
mangleSpecs []string
|
||||
ip string
|
||||
chain string
|
||||
v6 bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -37,6 +39,12 @@ type ShutdownState struct {
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||
|
||||
// IPv6 counterparts
|
||||
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
||||
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -67,6 +75,28 @@ func (s *ShutdownState) Cleanup() error {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
}
|
||||
|
||||
// Clean up v6 state even if the current run has no IPv6.
|
||||
// The previous run may have left ip6tables rules behind.
|
||||
if !ipt.hasIPv6() {
|
||||
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
|
||||
log.Warnf("failed to create v6 components for cleanup: %v", err)
|
||||
}
|
||||
}
|
||||
if ipt.hasIPv6() {
|
||||
if s.RouteRules6 != nil {
|
||||
ipt.router6.rules = s.RouteRules6
|
||||
}
|
||||
if s.RouteIPsetCounter6 != nil {
|
||||
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||
}
|
||||
if s.ACLEntries6 != nil {
|
||||
ipt.aclMgr6.entries = s.ACLEntries6
|
||||
}
|
||||
if s.ACLIPsetStore6 != nil {
|
||||
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
||||
}
|
||||
}
|
||||
|
||||
if err := ipt.Close(nil); err != nil {
|
||||
return fmt.Errorf("reset iptables manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -33,15 +33,12 @@ const (
|
||||
|
||||
const flushError = "flush: %w"
|
||||
|
||||
var (
|
||||
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
)
|
||||
|
||||
type AclManager struct {
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
af addrFamily
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
@@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routingFwChainName: routingFwChainName,
|
||||
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
||||
|
||||
ipsetStore: newIpsetStore(),
|
||||
rules: make(map[string]*Rule),
|
||||
@@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if _, ok := ips[r.ip.String()]; ok {
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
||||
}
|
||||
@@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering(
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(9),
|
||||
Offset: m.af.protoOffset,
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
protoData, err := protoToInt(proto)
|
||||
protoData, err := m.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||
}
|
||||
@@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering(
|
||||
})
|
||||
}
|
||||
|
||||
rawIP := ip.To4()
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
addrOffset := uint32(12)
|
||||
|
||||
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: addrOffset,
|
||||
Len: 4,
|
||||
Offset: m.af.srcAddrOffset,
|
||||
Len: m.af.addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
@@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
|
||||
|
||||
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
||||
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
||||
rawIP := ip.To4()
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
@@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
KeyType: m.af.setKeyType,
|
||||
}
|
||||
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
@@ -707,15 +702,12 @@ func ifname(n string) []byte {
|
||||
return b
|
||||
}
|
||||
|
||||
func protoToInt(protocol firewall.Protocol) (uint8, error) {
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
return unix.IPPROTO_TCP, nil
|
||||
case firewall.ProtocolUDP:
|
||||
return unix.IPPROTO_UDP, nil
|
||||
case firewall.ProtocolICMP:
|
||||
return unix.IPPROTO_ICMP, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
// ipToBytes converts net.IP to the correct byte length for the address family.
|
||||
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
||||
if af.addrLen == 4 {
|
||||
return ip.To4()
|
||||
}
|
||||
return ip.To16()
|
||||
}
|
||||
|
||||
|
||||
81
client/firewall/nftables/addr_family_linux.go
Normal file
81
client/firewall/nftables/addr_family_linux.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// addrFamily holds protocol-specific constants for nftables expression building.
|
||||
type addrFamily struct {
|
||||
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
|
||||
protoOffset uint32
|
||||
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
|
||||
srcAddrOffset uint32
|
||||
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
|
||||
dstAddrOffset uint32
|
||||
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
|
||||
addrLen uint32
|
||||
// totalBits is the address size in bits (32 for v4, 128 for v6)
|
||||
totalBits int
|
||||
// setKeyType is the nftables set data type for addresses
|
||||
setKeyType nftables.SetDatatype
|
||||
// tableFamily is the nftables table family
|
||||
tableFamily nftables.TableFamily
|
||||
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
|
||||
icmpProto uint8
|
||||
}
|
||||
|
||||
var (
|
||||
// afIPv4 defines IPv4 header layout and nftables types.
|
||||
afIPv4 = addrFamily{
|
||||
protoOffset: 9,
|
||||
srcAddrOffset: 12,
|
||||
dstAddrOffset: 16,
|
||||
addrLen: net.IPv4len,
|
||||
totalBits: 8 * net.IPv4len,
|
||||
setKeyType: nftables.TypeIPAddr,
|
||||
tableFamily: nftables.TableFamilyIPv4,
|
||||
icmpProto: unix.IPPROTO_ICMP,
|
||||
}
|
||||
// afIPv6 defines IPv6 header layout and nftables types.
|
||||
afIPv6 = addrFamily{
|
||||
protoOffset: 6,
|
||||
srcAddrOffset: 8,
|
||||
dstAddrOffset: 24,
|
||||
addrLen: net.IPv6len,
|
||||
totalBits: 8 * net.IPv6len,
|
||||
setKeyType: nftables.TypeIP6Addr,
|
||||
tableFamily: nftables.TableFamilyIPv6,
|
||||
icmpProto: unix.IPPROTO_ICMPV6,
|
||||
}
|
||||
)
|
||||
|
||||
// familyForAddr returns the address family for the given IP.
|
||||
func familyForAddr(is4 bool) addrFamily {
|
||||
if is4 {
|
||||
return afIPv4
|
||||
}
|
||||
return afIPv6
|
||||
}
|
||||
|
||||
// protoNum converts a firewall protocol to the IP protocol number,
|
||||
// using the correct ICMP variant for the address family.
|
||||
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
return unix.IPPROTO_TCP, nil
|
||||
case firewall.ProtocolUDP:
|
||||
return unix.IPPROTO_UDP, nil
|
||||
case firewall.ProtocolICMP:
|
||||
return af.icmpProto, nil
|
||||
case firewall.ProtocolALL:
|
||||
return 0, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
}
|
||||
@@ -49,8 +49,13 @@ type Manager struct {
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
router6 *router
|
||||
aclManager6 *AclManager
|
||||
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
}
|
||||
@@ -62,7 +67,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
|
||||
tableName := getTableName()
|
||||
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||
@@ -75,9 +81,54 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
|
||||
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
|
||||
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
||||
|
||||
var err error
|
||||
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
}
|
||||
|
||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
||||
func (m *Manager) hasIPv6() bool {
|
||||
return m.router6 != nil
|
||||
}
|
||||
|
||||
func (m *Manager) initIPv6() error {
|
||||
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 work table: %w", err)
|
||||
}
|
||||
|
||||
if err := m.router6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 acl manager init: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Init nftables firewall manager
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
workTable, err := m.createWorkTable()
|
||||
@@ -94,6 +145,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.initIPv6(); err != nil {
|
||||
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
|
||||
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
@@ -141,12 +199,14 @@ func (m *Manager) AddPeerFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
rawIP := ip.To4()
|
||||
if rawIP == nil {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
if ip.To4() != nil {
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
|
||||
}
|
||||
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
@@ -160,8 +220,11 @@ func (m *Manager) AddRouteFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
@@ -172,14 +235,38 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6Rule(rule) {
|
||||
return m.aclManager6.DeletePeerRule(rule)
|
||||
}
|
||||
return m.aclManager.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
func isIPv6Rule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
||||
}
|
||||
|
||||
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
||||
// For static routes, the destination prefix determines the family. For dynamic
|
||||
// routes (DomainSet), the sources determine the family since management
|
||||
// duplicates dynamic rules per family.
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule.
|
||||
// Route rules are keyed by content hash, so the rule exists in exactly one
|
||||
// router. We check v4 first; if the key isn't there, try v6.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||
return m.router6.DeleteRouteRule(rule)
|
||||
}
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
@@ -195,17 +282,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Dynamic routes (DomainSet) need NAT in both tables since resolved IPs
|
||||
// can be either v4 or v6.
|
||||
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||
v6Pair := pair
|
||||
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveNatRule(pair)
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||
v6Pair := pair
|
||||
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("remove v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
// AllowNetbird allows netbird interface traffic.
|
||||
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
|
||||
// which doesn't override DROP rules in external tables (e.g. firewalld).
|
||||
// Should add passthrough rules to external chains (like the native mode router's
|
||||
// addExternalChainsRules does) for both the netbird table family and inet tables.
|
||||
// The netbird table itself is fine (routing chains already exist there), but
|
||||
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
@@ -217,6 +350,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create default allow rules: %w", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create v6 default allow rules: %w", err)
|
||||
}
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
@@ -226,7 +364,13 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the firewall manager
|
||||
@@ -238,6 +382,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("reset router: %v", err)
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
return fmt.Errorf("reset v6 router: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
return fmt.Errorf("cleanup netbird tables: %v", err)
|
||||
}
|
||||
@@ -299,6 +449,12 @@ func (m *Manager) Flush() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.Flush(); err != nil {
|
||||
return fmt.Errorf("flush v6 acl: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.refreshNoTrackChains(); err != nil {
|
||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||
}
|
||||
@@ -311,6 +467,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
|
||||
return m.router6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -319,6 +478,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasDNATRule(rule.ID()) {
|
||||
return m.router6.DeleteDNATRule(rule)
|
||||
}
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -327,7 +489,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
var v4Prefixes, v6Prefixes []netip.Prefix
|
||||
for _, p := range prefixes {
|
||||
if p.Addr().Is6() {
|
||||
v6Prefixes = append(v6Prefixes, p)
|
||||
} else {
|
||||
v4Prefixes = append(v4Prefixes, p)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
@@ -335,6 +516,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && localAddr.Is6() {
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
@@ -343,6 +527,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && localAddr.Is6() {
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
@@ -517,7 +704,11 @@ func (m *Manager) refreshNoTrackChains() error {
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
return m.createWorkTableFamily(nftables.TableFamilyIPv4)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
@@ -529,7 +720,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
@@ -385,10 +385,132 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "failed to add NAT rule")
|
||||
|
||||
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{8080}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
})
|
||||
require.NoError(t, err, "failed to add DNAT rule")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
|
||||
})
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("ip6tables-save"); err != nil {
|
||||
t.Skipf("ip6tables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// Seed ip6 tables in the nft backend. Docker may not create them.
|
||||
seedIp6tables(t)
|
||||
|
||||
ifaceMockV6 := &iFaceMock{
|
||||
NameFunc: func() string { return "wt-test" },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil), "close manager")
|
||||
|
||||
stdout, stderr := runIp6tablesSave(t)
|
||||
verifyIp6tablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("fd00::2")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "add v6 peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "add v6 route filtering rule")
|
||||
|
||||
err = manager.AddNatRule(fw.RouterPair{
|
||||
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
|
||||
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||
Masquerade: true,
|
||||
})
|
||||
require.NoError(t, err, "add v6 NAT rule")
|
||||
|
||||
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{8080}},
|
||||
TranslatedAddress: netip.MustParseAddr("fd00::2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
})
|
||||
require.NoError(t, err, "add v6 DNAT rule")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
|
||||
})
|
||||
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
stdout, stderr = runIp6tablesSave(t)
|
||||
verifyIp6tablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func seedIp6tables(t *testing.T) {
|
||||
t.Helper()
|
||||
for _, tc := range []struct{ table, chain string }{
|
||||
{"filter", "FORWARD"},
|
||||
{"nat", "POSTROUTING"},
|
||||
{"mangle", "FORWARD"},
|
||||
} {
|
||||
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
|
||||
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
|
||||
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
|
||||
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
|
||||
}
|
||||
}
|
||||
|
||||
func runIp6tablesSave(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd := exec.Command("ip6tables-save")
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
require.NoError(t, cmd.Run(), "ip6tables-save failed")
|
||||
return stdout.String(), stderr.String()
|
||||
}
|
||||
|
||||
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
|
||||
t.Helper()
|
||||
require.NotContains(t, stdout, "Table `nat' is incompatible",
|
||||
"ip6tables-save: nat table incompatible. Full output: %s", stdout)
|
||||
require.NotContains(t, stdout, "Table `mangle' is incompatible",
|
||||
"ip6tables-save: mangle table incompatible. Full output: %s", stdout)
|
||||
require.NotContains(t, stdout, "Table `filter' is incompatible",
|
||||
"ip6tables-save: filter table incompatible. Full output: %s", stdout)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
|
||||
@@ -46,8 +46,10 @@ const (
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
|
||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||
maxPrefixesSet = 1500
|
||||
@@ -72,6 +74,7 @@ type router struct {
|
||||
rules map[string]*nftables.Rule
|
||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||
|
||||
af addrFamily
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
@@ -84,6 +87,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
mtu: mtu,
|
||||
@@ -142,7 +146,7 @@ func (r *router) Reset() error {
|
||||
func (r *router) removeNatPreroutingRules() error {
|
||||
table := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Family: r.af.tableFamily,
|
||||
}
|
||||
chain := &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
@@ -175,7 +179,7 @@ func (r *router) removeNatPreroutingRules() error {
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
@@ -407,7 +411,7 @@ func (r *router) AddRouteFiltering(
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := protoToInt(proto)
|
||||
protoNum, err := r.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -467,7 +471,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return getIpSetExprs(ref, isSource)
|
||||
return r.getIpSetExprs(ref, isSource)
|
||||
}
|
||||
|
||||
func (r *router) iptablesProto() iptables.Protocol {
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
return iptables.ProtocolIPv6
|
||||
}
|
||||
return iptables.ProtocolIPv4
|
||||
}
|
||||
|
||||
func (r *router) hasRule(id string) bool {
|
||||
_, ok := r.rules[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *router) hasDNATRule(id string) bool {
|
||||
_, ok := r.rules[id+dnatSuffix]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
@@ -516,10 +537,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
KeyType: r.af.setKeyType,
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
nElements := len(elements)
|
||||
|
||||
maxElements := maxPrefixesSet * 2
|
||||
@@ -552,23 +573,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
|
||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||
firstIP := prefix.Addr()
|
||||
lastIP := calculateLastIP(prefix).Next()
|
||||
|
||||
elements = append(elements,
|
||||
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
|
||||
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
@@ -578,10 +593,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||
hostMask := ^uint32(0) >> prefix.Masked().Bits()
|
||||
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
|
||||
masked := prefix.Masked()
|
||||
if masked.Addr().Is4() {
|
||||
hostMask := ^uint32(0) >> masked.Bits()
|
||||
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
// IPv6: set host bits to all 1s
|
||||
b := masked.Addr().As16()
|
||||
bits := masked.Bits()
|
||||
for i := bits; i < 128; i++ {
|
||||
b[i/8] |= 1 << (7 - i%8)
|
||||
}
|
||||
return netip.AddrFrom16(b)
|
||||
}
|
||||
|
||||
// Utility function to convert netip.Addr to uint32.
|
||||
@@ -833,9 +858,12 @@ func (r *router) addPostroutingRules() {
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
exprsOut := []expr.Any{
|
||||
&expr.Meta{
|
||||
@@ -1042,17 +1070,22 @@ func (r *router) acceptFilterTableRules() error {
|
||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
// Try iptables first and fallback to nftables if iptables is not available.
|
||||
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
// iptables is not available but the filter table exists
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
|
||||
return r.acceptFilterRulesIptables(ipt)
|
||||
if err := r.acceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
@@ -1221,13 +1254,17 @@ func (r *router) removeFilterTableRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipt, err := iptables.New()
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
|
||||
return r.removeAcceptFilterRulesIptables(ipt)
|
||||
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
@@ -1294,7 +1331,7 @@ func (r *router) removeExternalChainsRules() error {
|
||||
func (r *router) findExternalChains() []*nftables.Chain {
|
||||
var chains []*nftables.Chain
|
||||
|
||||
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
||||
|
||||
for _, family := range families {
|
||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||
@@ -1318,8 +1355,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
||||
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1460,7 +1497,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(rule.Protocol)
|
||||
protoNum, err := r.af.protoNum(rule.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -1523,7 +1560,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: regProtoMin,
|
||||
RegProtoMax: regProtoMax,
|
||||
@@ -1619,7 +1656,7 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Family: r.af.tableFamily,
|
||||
},
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
@@ -1654,8 +1691,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
Offset: r.af.dstAddrOffset,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
@@ -1733,7 +1770,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
@@ -1755,7 +1792,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -1786,7 +1823,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
@@ -1799,7 +1840,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
@@ -1868,45 +1909,44 @@ func (r *router) applyNetwork(
|
||||
}
|
||||
|
||||
if network.IsPrefix() {
|
||||
return applyPrefix(network.Prefix, isSource), nil
|
||||
return r.applyPrefix(network.Prefix, isSource), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
// dst offset by default
|
||||
offset := r.af.dstAddrOffset
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = 12
|
||||
offset = r.af.srcAddrOffset
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
// 0.0.0.0/0 doesn't need extra expressions
|
||||
// unspecified address (/0) doesn't need extra expressions
|
||||
if ones == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(ones, 32)
|
||||
mask := net.CIDRMask(ones, r.af.totalBits)
|
||||
xor := make([]byte, r.af.addrLen)
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: 4,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
// netmask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Len: r.af.addrLen,
|
||||
Mask: mask,
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
Xor: xor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
@@ -1989,13 +2029,12 @@ func getCtNewExprs() []expr.Any {
|
||||
}
|
||||
}
|
||||
|
||||
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
// dst offset by default
|
||||
offset := r.af.dstAddrOffset
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = 12
|
||||
offset = r.af.srcAddrOffset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
@@ -2003,7 +2042,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: 4,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
|
||||
@@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
testRouter := &router{af: afIPv4}
|
||||
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
@@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTableIPv6()
|
||||
require.NoError(t, err, "Failed to create v6 work table")
|
||||
defer deleteWorkTableIPv6()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Single IPv6",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPv6 Subnets",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("2001:db8::/48"),
|
||||
netip.MustParsePrefix("fe80::/10"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping IPv6",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/48"),
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("fd00::1/128"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/48"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Mixed prefix lengths",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||
netip.MustParsePrefix("2001:db8:2::1/128"),
|
||||
netip.MustParsePrefix("fd00:abcd::/32"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||
require.NoError(t, err, "Failed to create IPv6 set")
|
||||
require.NotNil(t, set)
|
||||
|
||||
assert.Equal(t, setName, set.Name)
|
||||
assert.True(t, set.Interval)
|
||||
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
|
||||
|
||||
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
|
||||
require.NoError(t, err, "Failed to fetch created set")
|
||||
|
||||
elements, err := r.conn.GetSetElements(fetchedSet)
|
||||
require.NoError(t, err, "Failed to get set elements")
|
||||
|
||||
uniquePrefixes := make(map[string]bool)
|
||||
for _, elem := range elements {
|
||||
if !elem.IntervalEnd && len(elem.Key) == 16 {
|
||||
ip := netip.AddrFrom16([16]byte(elem.Key))
|
||||
uniquePrefixes[ip.String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
expectedCount := len(tt.expected)
|
||||
if expectedCount == 0 {
|
||||
expectedCount = len(tt.sources)
|
||||
}
|
||||
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
|
||||
|
||||
r.conn.DelSet(set)
|
||||
require.NoError(t, r.conn.Flush())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createWorkTableIPv6() (*nftables.Table, error) {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
|
||||
err = sConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
func deleteWorkTableIPv6() {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
sConn.DelTable(t)
|
||||
_ = sConn.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||
t.Helper()
|
||||
|
||||
@@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
|
||||
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||
var metaFound, cmpFound bool
|
||||
expectedProto, _ := protoToInt(proto)
|
||||
expectedProto, _ := afIPv4.protoNum(proto)
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Meta:
|
||||
@@ -854,3 +985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||
}
|
||||
|
||||
func TestCalculateLastIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
prefix string
|
||||
want string
|
||||
}{
|
||||
{"10.0.0.0/24", "10.0.0.255"},
|
||||
{"10.0.0.0/32", "10.0.0.0"},
|
||||
{"0.0.0.0/0", "255.255.255.255"},
|
||||
{"192.168.1.0/28", "192.168.1.15"},
|
||||
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
|
||||
{"fd00::/128", "fd00::"},
|
||||
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
|
||||
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.prefix, func(t *testing.T) {
|
||||
prefix := netip.MustParsePrefix(tt.prefix)
|
||||
got := calculateLastIP(prefix)
|
||||
assert.Equal(t, tt.want, got.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
||||
r := &router{af: afIPv6}
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("2001:db8::1/128"),
|
||||
}
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
|
||||
// Each prefix produces 2 elements (start + end)
|
||||
require.Len(t, elements, 4)
|
||||
|
||||
// fd00::/64 start
|
||||
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
|
||||
assert.False(t, elements[0].IntervalEnd)
|
||||
|
||||
// fd00::/64 end (fd00:0:0:1::, one past the last)
|
||||
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
|
||||
assert.True(t, elements[1].IntervalEnd)
|
||||
|
||||
// 2001:db8::1/128 start
|
||||
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
|
||||
assert.False(t, elements[2].IntervalEnd)
|
||||
|
||||
// 2001:db8::1/128 end (2001:db8::2)
|
||||
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
|
||||
assert.True(t, elements[3].IntervalEnd)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
var merr *multierror.Error
|
||||
if isFirewallRuleActive(firewallRuleName) {
|
||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
||||
if isFirewallRuleActive(firewallRuleName + "-v6") {
|
||||
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
@@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
if !isFirewallRuleActive(firewallRuleName) {
|
||||
if err := manageFirewallRule(firewallRuleName,
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+m.wgIface.Address().IP.String(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return manageFirewallRule(firewallRuleName,
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+m.wgIface.Address().IP.String(),
|
||||
)
|
||||
|
||||
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
|
||||
if err := manageFirewallRule(firewallRuleName+"-v6",
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+v6.String(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -64,5 +65,7 @@ type ConnKey struct {
|
||||
}
|
||||
|
||||
func (c ConnKey) String() string {
|
||||
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) +
|
||||
" → " +
|
||||
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
|
||||
}
|
||||
|
||||
@@ -21,9 +21,10 @@ const (
|
||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||
ICMPCleanupInterval = 15 * time.Second
|
||||
|
||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||
MaxICMPPayloadLength = 28
|
||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info.
|
||||
// IPv4: 20-byte header + 8-byte transport = 28 bytes.
|
||||
// IPv6: 40-byte header + 8-byte transport = 48 bytes.
|
||||
MaxICMPPayloadLength = 48
|
||||
)
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
@@ -74,32 +75,64 @@ func (info ICMPInfo) String() string {
|
||||
return info.TypeCode.String()
|
||||
}
|
||||
|
||||
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||
// isErrorMessage returns true if this ICMP type carries original packet info.
|
||||
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
|
||||
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
|
||||
// kept as a literal.
|
||||
func (info ICMPInfo) isErrorMessage() bool {
|
||||
typ := info.TypeCode.Type()
|
||||
return typ == 3 || // Destination Unreachable
|
||||
typ == 5 || // Redirect
|
||||
typ == 11 || // Time Exceeded
|
||||
typ == 12 // Parameter Problem
|
||||
// ICMPv4 error types
|
||||
if typ == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
typ == layers.ICMPv4TypeRedirect ||
|
||||
typ == layers.ICMPv4TypeTimeExceeded ||
|
||||
typ == layers.ICMPv4TypeParameterProblem {
|
||||
return true
|
||||
}
|
||||
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
|
||||
if typ == layers.ICMPv6TypeDestinationUnreachable ||
|
||||
typ == layers.ICMPv6TypePacketTooBig ||
|
||||
typ == layers.ICMPv6TypeParameterProblem {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||
func (info ICMPInfo) parseOriginalPacket() string {
|
||||
if info.PayloadLen < MaxICMPPayloadLength {
|
||||
if info.PayloadLen == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TODO: handle IPv6
|
||||
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||
version := (info.PayloadData[0] >> 4) & 0xF
|
||||
|
||||
var protocol uint8
|
||||
var srcIP, dstIP net.IP
|
||||
var transportData []byte
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
// 20-byte IPv4 header + 8-byte transport minimum
|
||||
if info.PayloadLen < 28 {
|
||||
return ""
|
||||
}
|
||||
protocol = info.PayloadData[9]
|
||||
srcIP = net.IP(info.PayloadData[12:16])
|
||||
dstIP = net.IP(info.PayloadData[16:20])
|
||||
transportData = info.PayloadData[20:]
|
||||
case 6:
|
||||
// 40-byte IPv6 header + 8-byte transport minimum
|
||||
if info.PayloadLen < 48 {
|
||||
return ""
|
||||
}
|
||||
// Next Header field in IPv6 header
|
||||
protocol = info.PayloadData[6]
|
||||
srcIP = net.IP(info.PayloadData[8:24])
|
||||
dstIP = net.IP(info.PayloadData[24:40])
|
||||
transportData = info.PayloadData[40:]
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
protocol := info.PayloadData[9]
|
||||
srcIP := net.IP(info.PayloadData[12:16])
|
||||
dstIP := net.IP(info.PayloadData[16:20])
|
||||
|
||||
transportData := info.PayloadData[20:]
|
||||
|
||||
switch nftypes.Protocol(protocol) {
|
||||
case nftypes.TCP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
@@ -247,9 +280,10 @@ func (t *ICMPTracker) track(
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request.
|
||||
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -301,6 +335,13 @@ func (t *ICMPTracker) cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
|
||||
if ip.Is6() {
|
||||
return nftypes.ICMPv6
|
||||
}
|
||||
return nftypes.ICMP
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *ICMPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
@@ -316,7 +357,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||
Protocol: icmpProtocolForAddr(conn.SourceIP),
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
ICMPType: conn.ICMPType,
|
||||
@@ -334,7 +375,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
|
||||
Type: nftypes.TypeStart,
|
||||
RuleID: ruleID,
|
||||
Direction: direction,
|
||||
Protocol: nftypes.ICMP,
|
||||
Protocol: icmpProtocolForAddr(srcIP),
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
ICMPType: typ,
|
||||
|
||||
@@ -35,8 +35,10 @@ import (
|
||||
const (
|
||||
layerTypeAll = 255
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
// ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation
|
||||
ipv4TCPHeaderMinSize = 40
|
||||
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
|
||||
ipv6TCPHeaderMinSize = 60
|
||||
)
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
@@ -137,9 +139,10 @@ type Manager struct {
|
||||
netstackServices map[serviceKey]struct{}
|
||||
netstackServiceMutex sync.RWMutex
|
||||
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
mtu uint16
|
||||
mssClampValueIPv4 uint16
|
||||
mssClampValueIPv6 uint16
|
||||
mssClampEnabled bool
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -152,11 +155,28 @@ type decoder struct {
|
||||
icmp4 layers.ICMPv4
|
||||
icmp6 layers.ICMPv6
|
||||
decoded []gopacket.LayerType
|
||||
parser *gopacket.DecodingLayerParser
|
||||
parser4 *gopacket.DecodingLayerParser
|
||||
parser6 *gopacket.DecodingLayerParser
|
||||
|
||||
dnatOrigPort uint16
|
||||
}
|
||||
|
||||
// decodePacket decodes packet data using the appropriate parser based on IP version.
|
||||
func (d *decoder) decodePacket(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return errors.New("empty packet")
|
||||
}
|
||||
version := data[0] >> 4
|
||||
switch version {
|
||||
case 4:
|
||||
return d.parser4.DecodeLayers(data, &d.decoded)
|
||||
case 6:
|
||||
return d.parser6.DecodeLayers(data, &d.decoded)
|
||||
default:
|
||||
return fmt.Errorf("unknown IP version %d", version)
|
||||
}
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
@@ -214,11 +234,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
},
|
||||
@@ -244,7 +270,8 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
|
||||
if !disableMSSClamping {
|
||||
m.mssClampEnabled = true
|
||||
m.mssClampValue = mtu - ipTCPHeaderMinSize
|
||||
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
|
||||
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
|
||||
}
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
@@ -271,9 +298,14 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
|
||||
wgPrefix := iface.Address().Network
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
if v6 := iface.Address().IPv6Net; v6.IsValid() {
|
||||
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||
}
|
||||
|
||||
rule, err := m.addRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
sources,
|
||||
firewall.Network{Prefix: wgPrefix},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
@@ -281,7 +313,22 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
|
||||
firewall.ActionDrop,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("block wg nte : %w", err)
|
||||
return nil, fmt.Errorf("block wg v4 net: %w", err)
|
||||
}
|
||||
|
||||
if v6Net := iface.Address().IPv6Net; v6Net.IsValid() {
|
||||
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
||||
if _, err := m.addRouteFiltering(
|
||||
nil,
|
||||
sources,
|
||||
firewall.Network{Prefix: v6Net},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionDrop,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("block wg v6 net: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Block networks that we're a client of
|
||||
@@ -498,7 +545,7 @@ func (m *Manager) addRouteFiltering(
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
|
||||
protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)),
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
@@ -650,11 +697,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
|
||||
destinations := matches[0].destinations
|
||||
for _, prefix := range prefixes {
|
||||
if prefix.Addr().Is4() {
|
||||
destinations = append(destinations, prefix)
|
||||
}
|
||||
}
|
||||
destinations = append(destinations, prefixes...)
|
||||
|
||||
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
||||
cmp := a.Addr().Compare(b.Addr())
|
||||
@@ -693,7 +736,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -774,12 +817,28 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var mssClampValue uint16
|
||||
var ipHeaderSize int
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
mssClampValue = m.mssClampValueIPv4
|
||||
ipHeaderSize = int(d.ip4.IHL) * 4
|
||||
if ipHeaderSize < 20 {
|
||||
return false
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
mssClampValue = m.mssClampValueIPv6
|
||||
ipHeaderSize = 40
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
mssOptionIndex := -1
|
||||
var currentMSS uint16
|
||||
for i, opt := range d.tcp.Options {
|
||||
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
||||
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
||||
if currentMSS > m.mssClampValue {
|
||||
if currentMSS > mssClampValue {
|
||||
mssOptionIndex = i
|
||||
break
|
||||
}
|
||||
@@ -790,20 +849,15 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
ipHeaderSize := int(d.ip4.IHL) * 4
|
||||
if ipHeaderSize < 20 {
|
||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
|
||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool {
|
||||
tcpHeaderStart := ipHeaderSize
|
||||
tcpOptionsStart := tcpHeaderStart + 20
|
||||
|
||||
@@ -818,7 +872,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex,
|
||||
}
|
||||
|
||||
mssValueOffset := optOffset + 2
|
||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
|
||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue)
|
||||
|
||||
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
||||
return true
|
||||
@@ -828,18 +882,32 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
|
||||
tcpLayer := packetData[tcpHeaderStart:]
|
||||
tcpLength := len(packetData) - tcpHeaderStart
|
||||
|
||||
// Zero out existing checksum
|
||||
tcpLayer[16] = 0
|
||||
tcpLayer[17] = 0
|
||||
|
||||
// Build pseudo-header checksum based on IP version
|
||||
var pseudoSum uint32
|
||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
case layers.LayerTypeIPv6:
|
||||
for i := 0; i < 16; i += 2 {
|
||||
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
|
||||
}
|
||||
for i := 0; i < 16; i += 2 {
|
||||
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
|
||||
}
|
||||
pseudoSum += uint32(tcpLength)
|
||||
pseudoSum += uint32(layers.IPProtocolTCP)
|
||||
}
|
||||
|
||||
var sum = pseudoSum
|
||||
sum := pseudoSum
|
||||
for i := 0; i < tcpLength-1; i += 2 {
|
||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||
}
|
||||
@@ -877,6 +945,9 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
|
||||
}
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||
case layers.LayerTypeICMPv6:
|
||||
id, tc := icmpv6EchoFields(d)
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -890,6 +961,9 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||
case layers.LayerTypeICMPv6:
|
||||
id, tc := icmpv6EchoFields(d)
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
|
||||
}
|
||||
|
||||
d.dnatOrigPort = 0
|
||||
@@ -949,15 +1023,19 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
|
||||
// TODO: pass fragments of routed packets to forwarder
|
||||
if fragment {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
if d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
} else {
|
||||
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: optimize port DNAT by caching matched rules in conntrack
|
||||
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
||||
// Re-decode after port DNAT translation to update port information
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
@@ -966,7 +1044,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
@@ -1098,6 +1176,48 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return true
|
||||
}
|
||||
|
||||
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
|
||||
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
|
||||
// both families uniformly. The echo ID is in the first two payload bytes.
|
||||
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
|
||||
if len(d.icmp6.Payload) >= 2 {
|
||||
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
|
||||
}
|
||||
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
|
||||
switch d.icmp6.TypeCode.Type() {
|
||||
case layers.ICMPv6TypeEchoRequest:
|
||||
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
|
||||
case layers.ICMPv6TypeEchoReply:
|
||||
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
|
||||
default:
|
||||
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
|
||||
}
|
||||
return id, tc
|
||||
}
|
||||
|
||||
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
|
||||
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
|
||||
// ICMP rules since management sends a single ICMP rule for both families.
|
||||
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
|
||||
if ruleLayer == packetLayer {
|
||||
return true
|
||||
}
|
||||
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
|
||||
return true
|
||||
}
|
||||
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
|
||||
if p.Addr().Is6() {
|
||||
return layers.LayerTypeIPv6
|
||||
}
|
||||
return layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
@@ -1121,8 +1241,10 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
|
||||
return nftypes.TCP
|
||||
case layers.LayerTypeUDP:
|
||||
return nftypes.UDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
case layers.LayerTypeICMPv4:
|
||||
return nftypes.ICMP
|
||||
case layers.LayerTypeICMPv6:
|
||||
return nftypes.ICMPv6
|
||||
default:
|
||||
return nftypes.ProtocolUnknown
|
||||
}
|
||||
@@ -1143,7 +1265,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
// It returns true, false if the packet is valid and not a fragment.
|
||||
// It returns true, true if the packet is a fragment and valid.
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
return false, false
|
||||
}
|
||||
@@ -1156,10 +1278,18 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
}
|
||||
|
||||
// Fragments are also valid
|
||||
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
ip4 := d.ip4
|
||||
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
||||
return true, true
|
||||
if l == 1 {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 {
|
||||
return true, true
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
|
||||
// only decoded the IPv6 layer, the transport is in a fragment.
|
||||
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
|
||||
return true, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1197,21 +1327,34 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
|
||||
size,
|
||||
)
|
||||
|
||||
// TODO: ICMPv6
|
||||
case layers.LayerTypeICMPv6:
|
||||
id, _ := icmpv6EchoFields(d)
|
||||
return m.icmpTracker.IsValidInbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
id,
|
||||
d.icmp6.TypeCode.Type(),
|
||||
size,
|
||||
)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||
// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed.
|
||||
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||
return false
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeICMPv4:
|
||||
icmpType := d.icmp4.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
case layers.LayerTypeICMPv6:
|
||||
icmpType := d.icmp6.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv6TypePacketTooBig ||
|
||||
icmpType == layers.ICMPv6TypeTimeExceeded
|
||||
}
|
||||
|
||||
icmpType := d.icmp4.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
|
||||
@@ -1268,7 +1411,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
|
||||
if payloadLayer != rule.protoLayer {
|
||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1309,8 +1452,7 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
||||
// TODO: handle ipv6 vs ipv4 icmp rules
|
||||
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
|
||||
if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1515,7 +1657,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
||||
}
|
||||
|
||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||
if dstIP != m.wgIface.Address().IP {
|
||||
addr := m.wgIface.Address()
|
||||
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -1023,7 +1023,8 @@ func BenchmarkMSSClamping(b *testing.B) {
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
manager.mssClampValueIPv4 = 1240
|
||||
manager.mssClampValueIPv6 = 1220
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
@@ -1088,7 +1089,8 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
|
||||
|
||||
manager.mssClampEnabled = sc.enabled
|
||||
if sc.enabled {
|
||||
manager.mssClampValue = 1240
|
||||
manager.mssClampValueIPv4 = 1240
|
||||
manager.mssClampValueIPv6 = 1220
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
@@ -1141,7 +1143,8 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
manager.mssClampValueIPv4 = 1240
|
||||
manager.mssClampValueIPv6 = 1220
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
|
||||
@@ -539,53 +539,236 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerACLFilteringIPv6(t *testing.T) {
|
||||
localIP := netip.MustParseAddr("100.10.0.100")
|
||||
localIPv6 := netip.MustParseAddr("fd00::100")
|
||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||
wgNetV6 := netip.MustParsePrefix("fd00::/64")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
IPv6: localIPv6,
|
||||
IPv6Net: wgNetV6,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
srcIP string
|
||||
dstIP string
|
||||
proto fw.Protocol
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
ruleIP string
|
||||
ruleProto fw.Protocol
|
||||
ruleDstPort *fw.Port
|
||||
ruleAction fw.Action
|
||||
shouldBeBlocked bool
|
||||
}{
|
||||
{
|
||||
name: "IPv6: allow TCP from peer",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: allow UDP from peer",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: allow ICMPv6 from peer",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolICMP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: block TCP without rule",
|
||||
srcIP: "fd00::2",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6: drop rule",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 22,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{22}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6: allow all protocols",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 9999,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolALL,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "0.0.0.0",
|
||||
ruleProto: fw.ProtocolICMP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
|
||||
src := net.ParseIP(srcIP)
|
||||
dst := net.ParseIP(dstIP)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
}
|
||||
// Detect address family
|
||||
isV6 := src.To4() == nil
|
||||
|
||||
var err error
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
ipLayer.Protocol = layers.IPProtocolTCP
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
err = tcp.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(t, err)
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp)
|
||||
|
||||
case fw.ProtocolUDP:
|
||||
ipLayer.Protocol = layers.IPProtocolUDP
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
if isV6 {
|
||||
ip6 := &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
SrcIP: src,
|
||||
DstIP: dst,
|
||||
}
|
||||
err = udp.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(t, err)
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, udp)
|
||||
|
||||
case fw.ProtocolICMP:
|
||||
ipLayer.Protocol = layers.IPProtocolICMPv4
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
ip6.NextHeader = layers.IPProtocolTCP
|
||||
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
||||
_ = tcp.SetNetworkLayerForChecksum(ip6)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6, tcp)
|
||||
case fw.ProtocolUDP:
|
||||
ip6.NextHeader = layers.IPProtocolUDP
|
||||
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
||||
_ = udp.SetNetworkLayerForChecksum(ip6)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6, udp)
|
||||
case fw.ProtocolICMP:
|
||||
ip6.NextHeader = layers.IPProtocolICMPv6
|
||||
icmp := &layers.ICMPv6{
|
||||
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
|
||||
}
|
||||
_ = icmp.SetNetworkLayerForChecksum(ip6)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6, icmp)
|
||||
default:
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6)
|
||||
}
|
||||
} else {
|
||||
ip4 := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
SrcIP: src,
|
||||
DstIP: dst,
|
||||
}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp)
|
||||
|
||||
default:
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer)
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
ip4.Protocol = layers.IPProtocolTCP
|
||||
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
||||
_ = tcp.SetNetworkLayerForChecksum(ip4)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4, tcp)
|
||||
case fw.ProtocolUDP:
|
||||
ip4.Protocol = layers.IPProtocolUDP
|
||||
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
||||
_ = udp.SetNetworkLayerForChecksum(ip4)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4, udp)
|
||||
case fw.ProtocolICMP:
|
||||
ip4.Protocol = layers.IPProtocolICMPv4
|
||||
icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)}
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4, icmp)
|
||||
default:
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4)
|
||||
}
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -1498,3 +1681,103 @@ func TestRouteACLSet(t *testing.T) {
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
|
||||
// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass.
|
||||
// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only)
|
||||
// but the ACL decision itself is correct.
|
||||
func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||
manager := setupRoutedManager(t, "10.10.0.100/16")
|
||||
|
||||
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
||||
_, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: v6Dst},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionDrop,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
proto gopacket.LayerType
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
allowed bool
|
||||
}{
|
||||
{
|
||||
name: "IPv6 TCP to allowed dest",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 TCP wrong port",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 UDP not matched by TCP rule",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeICMPv6,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 to denied subnet",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 source outside allowed range",
|
||||
srcIP: netip.MustParseAddr("fe80::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.allowed, pass, "route ACL result mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -582,11 +582,16 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
}
|
||||
@@ -687,11 +692,16 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
}
|
||||
@@ -1097,8 +1107,8 @@ func TestMSSClamping(t *testing.T) {
|
||||
}()
|
||||
|
||||
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
||||
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
|
||||
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
|
||||
require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40")
|
||||
require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60")
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
@@ -1116,7 +1126,7 @@ func TestMSSClamping(t *testing.T) {
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
|
||||
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40")
|
||||
})
|
||||
|
||||
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
||||
@@ -1140,7 +1150,7 @@ func TestMSSClamping(t *testing.T) {
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||
})
|
||||
|
||||
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
||||
@@ -1312,13 +1322,18 @@ func TestShouldForward(t *testing.T) {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
|
||||
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||
err = d.decodePacket(buf.Bytes())
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
@@ -1378,6 +1393,44 @@ func TestShouldForward(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Add IPv6 to the interface and test dual-stack cases
|
||||
wgIPv6 := netip.MustParseAddr("fd00::1")
|
||||
otherIPv6 := netip.MustParseAddr("fd00::2")
|
||||
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgIP,
|
||||
Network: netip.PrefixFrom(wgIP, 24),
|
||||
IPv6: wgIPv6,
|
||||
IPv6Net: netip.PrefixFrom(wgIPv6, 64),
|
||||
}
|
||||
}
|
||||
|
||||
// Re-create manager to pick up the new address with IPv6
|
||||
require.NoError(t, manager.Close(nil))
|
||||
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
v6Cases := []struct {
|
||||
name string
|
||||
dstIP netip.Addr
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"},
|
||||
{"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"},
|
||||
{"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"},
|
||||
{"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"},
|
||||
}
|
||||
for _, tt := range v6Cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager.localForwarding = true
|
||||
manager.netstack = false
|
||||
decoder := createTCPDecoder(8080)
|
||||
result := manager.shouldForward(decoder, tt.dstIP)
|
||||
require.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Configure manager
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
@@ -47,17 +48,23 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet through WireGuard
|
||||
address := netHeader.DestinationAddress()
|
||||
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||
if err != nil {
|
||||
raw := pkt.NetworkHeader().View().AsSlice()
|
||||
if len(raw) == 0 {
|
||||
continue
|
||||
}
|
||||
var address tcpip.Address
|
||||
if raw[0]>>4 == 6 {
|
||||
address = header.IPv6(raw).DestinationAddress()
|
||||
} else {
|
||||
address = header.IPv4(raw).DestinationAddress()
|
||||
}
|
||||
|
||||
if err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()); err != nil {
|
||||
e.logger.Error1("CreateOutboundPacket: %v", err)
|
||||
continue
|
||||
}
|
||||
@@ -103,5 +110,7 @@ type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
// src and remote is swapped
|
||||
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) +
|
||||
" → " +
|
||||
net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort)))
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
@@ -36,25 +37,31 @@ type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
ipv6 tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
hasRawICMPv6Access bool
|
||||
pingSemaphore chan struct{}
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
icmp.NewProtocol6,
|
||||
},
|
||||
HandleLocal: false,
|
||||
})
|
||||
@@ -73,7 +80,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
Address: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
||||
PrefixLen: iface.Address().Network.Bits(),
|
||||
},
|
||||
}
|
||||
@@ -82,6 +89,19 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||
}
|
||||
|
||||
if v6 := iface.Address().IPv6; v6.IsValid() {
|
||||
v6Addr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv6.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFrom16(v6.As16()),
|
||||
PrefixLen: iface.Address().IPv6Net.Bits(),
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("add IPv6 protocol address: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
@@ -90,6 +110,14 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||
}
|
||||
|
||||
defaultSubnetV6, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom16([16]byte{}),
|
||||
tcpip.MaskFromBytes(make([]byte, 16)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating default v6 subnet: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
@@ -98,10 +126,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: defaultSubnet,
|
||||
NIC: nicID,
|
||||
},
|
||||
{Destination: defaultSubnet, NIC: nicID},
|
||||
{Destination: defaultSubnetV6, NIC: nicID},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -114,7 +140,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
ip: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
||||
ipv6: addrFromNetipAddr(iface.Address().IPv6),
|
||||
pingSemaphore: make(chan struct{}, 3),
|
||||
}
|
||||
|
||||
@@ -131,7 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
// ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's
|
||||
// network layer. This avoids duplicate echo replies (v4) and the v6
|
||||
// auto-reply bug where gVisor responds at the network layer before
|
||||
// our transport handler fires.
|
||||
|
||||
f.checkICMPCapability()
|
||||
|
||||
@@ -140,8 +170,30 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||
if len(payload) == 0 {
|
||||
return fmt.Errorf("empty packet")
|
||||
}
|
||||
|
||||
var protoNum tcpip.NetworkProtocolNumber
|
||||
switch payload[0] >> 4 {
|
||||
case 4:
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload))
|
||||
}
|
||||
if f.handleICMPDirect(payload) {
|
||||
return nil
|
||||
}
|
||||
protoNum = ipv4.ProtocolNumber
|
||||
case 6:
|
||||
if len(payload) < header.IPv6MinimumSize {
|
||||
return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload))
|
||||
}
|
||||
if f.handleICMPDirect(payload) {
|
||||
return nil
|
||||
}
|
||||
protoNum = ipv6.ProtocolNumber
|
||||
default:
|
||||
return fmt.Errorf("unknown IP version: %d", payload[0]>>4)
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
@@ -150,11 +202,95 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
defer pkt.DecRef()
|
||||
|
||||
if f.endpoint.dispatcher != nil {
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleICMPDirect intercepts ICMP packets from raw IP payloads before they
|
||||
// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that
|
||||
// the existing handlers expect, then dispatches to handleICMP/handleICMPv6.
|
||||
// This bypasses gVisor's network layer which causes duplicate v4 echo replies
|
||||
// and auto-replies to all v6 echo requests in promiscuous mode.
|
||||
//
|
||||
// Unlike gVisor's network layer, this does not validate ICMP checksums or
|
||||
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
|
||||
func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
|
||||
ip := header.IPv4(payload)
|
||||
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
ipHdrLen = int(ip.HeaderLength())
|
||||
if len(payload)-ipHdrLen < header.ICMPv4MinimumSize {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
||||
}
|
||||
|
||||
func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
|
||||
ip := header.IPv6(payload)
|
||||
if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
ipHdrLen = header.IPv6MinimumSize
|
||||
if len(payload)-ipHdrLen < header.ICMPv6MinimumSize {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleICMPDirect(payload []byte) bool {
|
||||
var (
|
||||
ipHdrLen int
|
||||
srcAddr tcpip.Address
|
||||
dstAddr tcpip.Address
|
||||
ok bool
|
||||
)
|
||||
switch payload[0] >> 4 {
|
||||
case 4:
|
||||
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
|
||||
case 6:
|
||||
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
|
||||
}
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Let gVisor handle ICMP destined for our own addresses natively.
|
||||
// Its network-layer auto-reply is correct and efficient for local traffic.
|
||||
if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
id := stack.TransportEndpointID{
|
||||
LocalAddress: dstAddr,
|
||||
RemoteAddress: srcAddr,
|
||||
}
|
||||
|
||||
// Build a PacketBuffer with headers consumed the same way gVisor would.
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
icmpPayload := payload[ipHdrLen:]
|
||||
if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if payload[0]>>4 == 6 {
|
||||
return f.handleICMPv6(id, pkt)
|
||||
}
|
||||
return f.handleICMP(id, pkt)
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the forwarder
|
||||
func (f *Forwarder) Stop() {
|
||||
f.cancel()
|
||||
@@ -167,11 +303,14 @@ func (f *Forwarder) Stop() {
|
||||
f.stack.Wait()
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
return netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||
}
|
||||
return addr.AsSlice()
|
||||
if f.netstack && f.ipv6.Equal(addr) {
|
||||
return netip.IPv6Loopback()
|
||||
}
|
||||
return addrToNetipAddr(addr)
|
||||
}
|
||||
|
||||
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||
@@ -205,23 +344,50 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
||||
}
|
||||
}
|
||||
|
||||
// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating.
|
||||
func addrFromNetipAddr(addr netip.Addr) tcpip.Address {
|
||||
if !addr.IsValid() {
|
||||
return tcpip.Address{}
|
||||
}
|
||||
if addr.Is4() {
|
||||
return tcpip.AddrFrom4(addr.As4())
|
||||
}
|
||||
return tcpip.AddrFrom16(addr.As16())
|
||||
}
|
||||
|
||||
// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating.
|
||||
func addrToNetipAddr(addr tcpip.Address) netip.Addr {
|
||||
switch addr.Len() {
|
||||
case 4:
|
||||
return netip.AddrFrom4(addr.As4())
|
||||
case 16:
|
||||
return netip.AddrFrom16(addr.As16())
|
||||
default:
|
||||
return netip.Addr{}
|
||||
}
|
||||
}
|
||||
|
||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||
func (f *Forwarder) checkICMPCapability() {
|
||||
f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger)
|
||||
f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger)
|
||||
}
|
||||
|
||||
func probeRawICMP(network, addr string, logger *nblog.Logger) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
conn, err := lc.ListenPacket(ctx, network, addr)
|
||||
if err != nil {
|
||||
f.hasRawICMPAccess = false
|
||||
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
||||
return
|
||||
logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
||||
logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err)
|
||||
}
|
||||
|
||||
f.hasRawICMPAccess = true
|
||||
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
||||
logger.Debug1("forwarder: raw %s socket access available", network)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
@@ -72,18 +72,23 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
|
||||
|
||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||
// The caller is responsible for closing the returned connection.
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
network, listenAddr := "ip4:icmp", "0.0.0.0"
|
||||
if v6 {
|
||||
network, listenAddr = "ip6:ipv6-icmp", "::"
|
||||
}
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
conn, err := lc.ListenPacket(ctx, network, listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||
}
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
dst := &net.IPAddr{IP: dstIP.AsSlice()}
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
@@ -102,7 +107,7 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
sendTime := time.Now()
|
||||
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, false, 5*time.Second)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||
return
|
||||
@@ -150,19 +155,23 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
||||
txPackets = 1
|
||||
}
|
||||
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||
|
||||
proto := nftypes.ICMP
|
||||
if srcIp.Is6() {
|
||||
proto = nftypes.ICMPv6
|
||||
}
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.ICMP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
Protocol: proto,
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
@@ -209,26 +218,159 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// handleICMPv6 handles ICMPv6 packets from the network stack.
|
||||
func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice())
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
||||
|
||||
if icmpHdr.Type() == header.ICMPv6EchoRequest {
|
||||
return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
||||
}
|
||||
|
||||
// For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting
|
||||
if !f.hasRawICMPv6Access {
|
||||
f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
||||
return false
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleICMPv6Echo handles ICMPv6 echo requests using the ping binary.
|
||||
func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
||||
select {
|
||||
case f.pingSemaphore <- struct{}{}:
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
||||
rxBytes := pkt.Size()
|
||||
|
||||
go func() {
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}()
|
||||
default:
|
||||
f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo.
|
||||
func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
||||
|
||||
pingStart := time.Now()
|
||||
if err := cmd.Run(); err != nil {
|
||||
f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err)
|
||||
return
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeICMPv6EchoReply(id, icmpData)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back.
|
||||
func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
||||
replyICMP := make([]byte, len(icmpData))
|
||||
copy(replyICMP, icmpData)
|
||||
|
||||
replyHdr := header.ICMPv6(replyICMP)
|
||||
replyHdr.SetType(header.ICMPv6EchoReply)
|
||||
replyHdr.SetChecksum(0)
|
||||
// ICMPv6Checksum computes the pseudo-header internally from Src/Dst.
|
||||
// Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero.
|
||||
replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: replyHdr,
|
||||
Src: id.LocalAddress,
|
||||
Dst: id.RemoteAddress,
|
||||
}))
|
||||
|
||||
return f.injectICMPv6Reply(id, replyICMP)
|
||||
}
|
||||
|
||||
// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer.
|
||||
func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
||||
ipHdr := make([]byte, header.IPv6MinimumSize)
|
||||
ip := header.IPv6(ipHdr)
|
||||
ip.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(len(icmpPayload)),
|
||||
TransportProtocol: header.ICMPv6ProtocolNumber,
|
||||
HopLimit: 64,
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, icmpPayload...)
|
||||
|
||||
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
const (
|
||||
pingBin = "ping"
|
||||
ping6Bin = "ping6"
|
||||
)
|
||||
|
||||
// buildPingCommand creates a platform-specific ping command.
|
||||
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
||||
// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6.
|
||||
func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd {
|
||||
timeoutSec := int(timeout.Seconds())
|
||||
if timeoutSec < 1 {
|
||||
timeoutSec = 1
|
||||
}
|
||||
|
||||
isV6 := target.Is6()
|
||||
timeoutStr := fmt.Sprintf("%d", timeoutSec)
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String())
|
||||
case "darwin", "ios":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
bin := pingBin
|
||||
if isV6 {
|
||||
bin = ping6Bin
|
||||
}
|
||||
return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String())
|
||||
case "freebsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String())
|
||||
case "openbsd", "netbsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
bin := pingBin
|
||||
if isV6 {
|
||||
bin = ping6Bin
|
||||
}
|
||||
return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String())
|
||||
case "windows":
|
||||
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
default:
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
||||
return exec.CommandContext(ctx, pingBin, "-c", "1", target.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@ package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -33,7 +32,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
}
|
||||
}()
|
||||
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
@@ -133,15 +132,14 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
// TODO: handle ipv6
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -158,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
}
|
||||
}()
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
@@ -276,15 +276,14 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
// TODO: handle ipv6
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
|
||||
@@ -4,89 +4,32 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
)
|
||||
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// fixed-size high array for upper byte of a IPv4 address
|
||||
ipv4Bitmap [256]*ipv4LowBitmap
|
||||
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
|
||||
// atomically so reads are lock-free.
|
||||
type localIPSnapshot struct {
|
||||
ips map[netip.Addr]struct{}
|
||||
}
|
||||
|
||||
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||
type ipv4LowBitmap struct {
|
||||
bitmap [8192]uint32
|
||||
type localIPManager struct {
|
||||
snapshot atomic.Pointer[localIPSnapshot]
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
return &localIPManager{}
|
||||
m := &localIPManager{}
|
||||
m.snapshot.Store(&localIPSnapshot{
|
||||
ips: make(map[netip.Addr]struct{}),
|
||||
})
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
if !ip.Is4() {
|
||||
return
|
||||
}
|
||||
ipv4 := ip.AsSlice()
|
||||
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
if _, exists := ipv4Set[ip]; !exists {
|
||||
ipv4Set[ip] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := uint16(ip[0])
|
||||
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -104,18 +47,19 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
||||
continue
|
||||
}
|
||||
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
parsed, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
parsed = parsed.Unmap()
|
||||
ips[parsed] = struct{}{}
|
||||
*addresses = append(*addresses, parsed)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
|
||||
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -123,20 +67,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[netip.Addr]struct{})
|
||||
var ipv4Addresses []netip.Addr
|
||||
ips := make(map[netip.Addr]struct{})
|
||||
var addresses []netip.Addr
|
||||
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
// loopback
|
||||
ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{}
|
||||
ips[netip.IPv6Loopback()] = struct{}{}
|
||||
|
||||
if iface != nil {
|
||||
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||
return err
|
||||
ip := iface.Address().IP
|
||||
ips[ip] = struct{}{}
|
||||
addresses = append(addresses, ip)
|
||||
if v6 := iface.Address().IPv6; v6.IsValid() {
|
||||
ips[v6] = struct{}{}
|
||||
addresses = append(addresses, v6)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,25 +89,24 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
processInterface(intf, ips, &addresses)
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.ipv4Bitmap = newIPv4Bitmap
|
||||
m.mu.Unlock()
|
||||
m.snapshot.Store(&localIPSnapshot{ips: ips})
|
||||
|
||||
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||
log.Debugf("Local IP addresses: %v", addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsLocalIP checks if the given IP is a local address. Lock-free on the read path.
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
if !ip.Is4() {
|
||||
return false
|
||||
s := m.snapshot.Load()
|
||||
|
||||
if ip.Is4() && ip.As4()[0] == 127 {
|
||||
return true
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
_, found := s.ips[ip]
|
||||
return found
|
||||
}
|
||||
|
||||
72
client/firewall/uspfilter/localip_bench_test.go
Normal file
72
client/firewall/uspfilter/localip_bench_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func setupManager(b *testing.B) *localIPManager {
|
||||
b.Helper()
|
||||
m := newLocalIPManager()
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
if err := m.UpdateLocalIPs(mock); err != nil {
|
||||
b.Fatalf("UpdateLocalIPs: %v", err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v4_hit(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("100.64.0.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v4_miss(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("8.8.8.8")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v6_hit(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("fd00::1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v6_miss(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("2001:db8::1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_loopback(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("127.0.0.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
@@ -72,14 +72,45 @@ func TestLocalIPManager(t *testing.T) {
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
name: "IPv6 address matches",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("fe80::1"),
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fd00::1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address does not match",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fd00::99"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No aliasing between similar IPs",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
testIP: netip.MustParseAddr("192.168.0.17"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("::1"),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -171,90 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MapImplementation is a version using map[string]struct{}
|
||||
type MapImplementation struct {
|
||||
localIPs map[string]struct{}
|
||||
}
|
||||
|
||||
func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces := make([]net.IP, 16)
|
||||
for i := range interfaces {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap
|
||||
bitmapManager := newLocalIPManager()
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
|
||||
// Setup map version
|
||||
mapManager := &MapImplementation{
|
||||
localIPs: make(map[string]struct{}),
|
||||
}
|
||||
for _, ip := range interfaces[:8] {
|
||||
mapManager.localIPs[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkWGPosition(b *testing.B) {
|
||||
wgIP := net.ParseIP("10.10.0.1")
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := newLocalIPManager()
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := newLocalIPManager()
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
}
|
||||
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -13,8 +13,6 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
var (
|
||||
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
||||
)
|
||||
@@ -25,10 +23,33 @@ const (
|
||||
destinationPortOffset = 2
|
||||
|
||||
// IP address offsets in IPv4 header
|
||||
sourceIPOffset = 12
|
||||
destinationIPOffset = 16
|
||||
ipv4SrcOffset = 12
|
||||
ipv4DstOffset = 16
|
||||
|
||||
// IP address offsets in IPv6 header
|
||||
ipv6SrcOffset = 8
|
||||
ipv6DstOffset = 24
|
||||
|
||||
// IPv6 fixed header length
|
||||
ipv6HeaderLen = 40
|
||||
)
|
||||
|
||||
// ipHeaderLen returns the IP header length based on the decoded layer type.
|
||||
func ipHeaderLen(d *decoder) (int, error) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
n := int(d.ip4.IHL) * 4
|
||||
if n < 20 {
|
||||
return 0, errInvalidIPHeaderLength
|
||||
}
|
||||
return n, nil
|
||||
case layers.LayerTypeIPv6:
|
||||
return ipv6HeaderLen, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
||||
}
|
||||
}
|
||||
|
||||
// ipv4Checksum calculates IPv4 header checksum.
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
@@ -234,14 +255,13 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
_, dstIP := extractPacketIPs(packetData, d)
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
@@ -256,14 +276,13 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
srcIP, _ := extractPacketIPs(packetData, d)
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
@@ -272,38 +291,96 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
|
||||
// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes.
|
||||
func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]})
|
||||
dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]})
|
||||
case layers.LayerTypeIPv6:
|
||||
src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16]))
|
||||
dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16]))
|
||||
}
|
||||
return src, dst
|
||||
}
|
||||
|
||||
// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error {
|
||||
hdrLen, err := ipHeaderLen(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource)
|
||||
case layers.LayerTypeIPv6:
|
||||
return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource)
|
||||
default:
|
||||
return fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
||||
if !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
return fmt.Errorf("cannot write IPv6 address into IPv4 packet")
|
||||
}
|
||||
|
||||
offset := ipv4DstOffset
|
||||
if isSource {
|
||||
offset = ipv4SrcOffset
|
||||
}
|
||||
|
||||
var oldIP [4]byte
|
||||
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
|
||||
copy(oldIP[:], packetData[offset:offset+4])
|
||||
newIPBytes := newIP.As4()
|
||||
copy(packetData[offset:offset+4], newIPBytes[:])
|
||||
|
||||
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
// Recalculate IPv4 header checksum
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen]))
|
||||
|
||||
// Update transport checksums incrementally
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
m.updateICMPChecksum(packetData, hdrLen)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
||||
if !newIP.Is6() {
|
||||
return fmt.Errorf("cannot write IPv4 address into IPv6 packet")
|
||||
}
|
||||
|
||||
offset := ipv6DstOffset
|
||||
if isSource {
|
||||
offset = ipv6SrcOffset
|
||||
}
|
||||
|
||||
var oldIP [16]byte
|
||||
copy(oldIP[:], packetData[offset:offset+16])
|
||||
newIPBytes := newIP.As16()
|
||||
copy(packetData[offset:offset+16], newIPBytes[:])
|
||||
|
||||
// IPv6 has no header checksum, only update transport checksums
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv6:
|
||||
// ICMPv6 checksum includes pseudo-header with addresses, use incremental update
|
||||
m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -351,6 +428,20 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// updateICMPv6Checksum updates ICMPv6 checksum after address change.
|
||||
// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies.
|
||||
func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+4 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := icmpStart + 2
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
@@ -515,12 +606,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
|
||||
|
||||
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
hdrLen, err := ipHeaderLen(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tcpStart := ipHeaderLen
|
||||
tcpStart := hdrLen
|
||||
if len(packetData) < tcpStart+4 {
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
}
|
||||
@@ -546,12 +637,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16,
|
||||
|
||||
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
hdrLen, err := ipHeaderLen(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
udpStart := ipHeaderLen
|
||||
udpStart := hdrLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return fmt.Errorf("packet too short for UDP header")
|
||||
}
|
||||
|
||||
@@ -342,12 +342,17 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
|
||||
// Parse the packet fresh each time to get a clean decoder
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
err = d.decodePacket(testPacket)
|
||||
assert.NoError(b, err)
|
||||
|
||||
manager.translateOutboundDNAT(testPacket, d)
|
||||
@@ -371,12 +376,17 @@ func BenchmarkDirectIPExtraction(b *testing.B) {
|
||||
b.Run("decoder_extraction", func(b *testing.B) {
|
||||
// Create decoder once for comparison
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err := d.parser.DecodeLayers(packet, &d.decoded)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
err := d.decodePacket(packet)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
@@ -86,13 +86,18 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
|
||||
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||
err := d.decodePacket(packetData)
|
||||
require.NoError(t, err)
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -112,10 +112,13 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string,
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||
ip := p.buildIPLayer()
|
||||
pktLayers := []gopacket.SerializableLayer{ip}
|
||||
ipLayer, err := p.buildIPLayer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pktLayers := []gopacket.SerializableLayer{ipLayer}
|
||||
|
||||
transportLayer, err := p.buildTransportLayer(ip)
|
||||
transportLayer, err := p.buildTransportLayer(ipLayer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -129,30 +132,43 @@ func (p *PacketBuilder) Build() ([]byte, error) {
|
||||
return serializePacket(pktLayers)
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) {
|
||||
if p.SrcIP.Is4() != p.DstIP.Is4() {
|
||||
return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP)
|
||||
}
|
||||
proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6())
|
||||
if p.SrcIP.Is6() {
|
||||
return &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: proto,
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
}, nil
|
||||
}
|
||||
return &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
Protocol: proto,
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
switch p.Protocol {
|
||||
case "tcp":
|
||||
return p.buildTCPLayer(ip)
|
||||
return p.buildTCPLayer(ipLayer)
|
||||
case "udp":
|
||||
return p.buildUDPLayer(ip)
|
||||
return p.buildUDPLayer(ipLayer)
|
||||
case "icmp":
|
||||
return p.buildICMPLayer()
|
||||
return p.buildICMPLayer(ipLayer)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(p.SrcPort),
|
||||
DstPort: layers.TCPPort(p.DstPort),
|
||||
@@ -164,24 +180,37 @@ func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableL
|
||||
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||
URG: p.TCPState != nil && p.TCPState.URG,
|
||||
}
|
||||
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||
if err := tcp.SetNetworkLayerForChecksum(nl); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||
}
|
||||
}
|
||||
return []gopacket.SerializableLayer{tcp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(p.SrcPort),
|
||||
DstPort: layers.UDPPort(p.DstPort),
|
||||
}
|
||||
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||
if err := udp.SetNetworkLayerForChecksum(nl); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||
}
|
||||
}
|
||||
return []gopacket.SerializableLayer{udp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
if p.SrcIP.Is6() || p.DstIP.Is6() {
|
||||
icmp := &layers.ICMPv6{
|
||||
TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode),
|
||||
}
|
||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||
_ = icmp.SetNetworkLayerForChecksum(nl)
|
||||
}
|
||||
return []gopacket.SerializableLayer{icmp}, nil
|
||||
}
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||
}
|
||||
@@ -204,14 +233,17 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||
func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol {
|
||||
switch protocol {
|
||||
case fw.ProtocolTCP:
|
||||
return int(layers.IPProtocolTCP)
|
||||
return layers.IPProtocolTCP
|
||||
case fw.ProtocolUDP:
|
||||
return int(layers.IPProtocolUDP)
|
||||
return layers.IPProtocolUDP
|
||||
case fw.ProtocolICMP:
|
||||
return int(layers.IPProtocolICMPv4)
|
||||
if isV6 {
|
||||
return layers.IPProtocolICMPv6
|
||||
}
|
||||
return layers.IPProtocolICMPv4
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@@ -234,7 +266,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
trace := &PacketTrace{Direction: direction}
|
||||
|
||||
// Initial packet decoding
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||
return trace
|
||||
}
|
||||
@@ -256,6 +288,8 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
trace.Protocol = "ICMP"
|
||||
case layers.LayerTypeICMPv6:
|
||||
trace.Protocol = "ICMPv6"
|
||||
}
|
||||
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||
@@ -415,7 +449,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
||||
return trace
|
||||
}
|
||||
@@ -434,7 +468,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
||||
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
||||
if portDNATApplied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
||||
return true
|
||||
}
|
||||
@@ -444,7 +478,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de
|
||||
|
||||
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
||||
if nat1to1Applied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
||||
return true
|
||||
}
|
||||
@@ -509,7 +543,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
srcIP, _ := extractPacketIPs(packetData, d)
|
||||
|
||||
translated := m.translateInboundReverse(packetData, d)
|
||||
if translated {
|
||||
@@ -539,7 +573,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
_, dstIP := extractPacketIPs(packetData, d)
|
||||
|
||||
translated := m.translateOutboundDNAT(packetData, d)
|
||||
if translated {
|
||||
|
||||
@@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||
}
|
||||
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
||||
addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port))
|
||||
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -2,7 +2,7 @@ package device
|
||||
|
||||
// TunAdapter is an interface for create tun device from external service
|
||||
type TunAdapter interface {
|
||||
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
||||
ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
||||
UpdateAddr(address string) error
|
||||
ProtectSocket(fd int32) bool
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
searchDomainsToString = ""
|
||||
}
|
||||
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
return nil, err
|
||||
|
||||
@@ -131,23 +131,32 @@ func (t *TunDevice) Device() *device.Device {
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||
return err
|
||||
if out, err := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("add v4 address: %s: %w", string(out), err)
|
||||
}
|
||||
|
||||
// dummy ipv6 so routing works
|
||||
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||
// Assign a dummy link-local so macOS enables IPv6 on the tun device.
|
||||
// When a real overlay v6 is present, use that instead.
|
||||
v6Addr := "fe80::/64"
|
||||
if t.address.HasIPv6() {
|
||||
v6Addr = t.address.IPv6String()
|
||||
}
|
||||
if out, err := exec.Command("ifconfig", t.name, "inet6", v6Addr).CombinedOutput(); err != nil {
|
||||
log.Warnf("failed to assign IPv6 address %s, continuing v4-only: %s: %v", v6Addr, string(out), err)
|
||||
t.address.ClearIPv6()
|
||||
}
|
||||
|
||||
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name)
|
||||
if out, err := routeCmd.CombinedOutput(); err != nil {
|
||||
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
|
||||
return err
|
||||
if out, err := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("add route %s via %s: %s: %w", t.address.Network, t.name, string(out), err)
|
||||
}
|
||||
|
||||
if t.address.HasIPv6() {
|
||||
if out, err := exec.Command("route", "add", "-inet6", "-net", t.address.IPv6Net.String(), "-interface", t.name).CombinedOutput(); err != nil {
|
||||
log.Warnf("failed to add route %s via %s, continuing v4-only: %s: %v", t.address.IPv6Net, t.name, string(out), err)
|
||||
t.address.ClearIPv6()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -151,8 +151,11 @@ func (t *TunDevice) MTU() uint16 {
|
||||
return t.mtu
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
|
||||
// todo implement
|
||||
// UpdateAddr updates the device address. On iOS the tunnel is managed by the
|
||||
// NetworkExtension, so we only store the new value. The extension picks up the
|
||||
// change on the next tunnel reconfiguration.
|
||||
func (t *TunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||
t.address = addr
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *TunKernelDevice) assignAddr() error {
|
||||
return t.link.assignAddr(t.address)
|
||||
return t.link.assignAddr(&t.address)
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||
|
||||
@@ -3,6 +3,7 @@ package device
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
@@ -63,8 +64,12 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("last ip: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("netstack using address: %s", t.address.IP)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu))
|
||||
addresses := []netip.Addr{t.address.IP}
|
||||
if t.address.HasIPv6() {
|
||||
addresses = append(addresses, t.address.IPv6)
|
||||
}
|
||||
log.Debugf("netstack using addresses: %v", addresses)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, addresses, dnsAddr, int(t.mtu))
|
||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||
tunIface, net, err := t.nsTun.Create()
|
||||
if err != nil {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type USPDevice struct {
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
port int
|
||||
@@ -30,10 +30,10 @@ type USPDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice {
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
|
||||
log.Infof("using userspace bind mode")
|
||||
|
||||
return &USPDevice{
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
port: port,
|
||||
@@ -43,7 +43,7 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu
|
||||
}
|
||||
}
|
||||
|
||||
func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||
func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
log.Info("create tun interface")
|
||||
tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
|
||||
if err != nil {
|
||||
@@ -75,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
if t.device == nil {
|
||||
return nil, fmt.Errorf("device is not ready yet")
|
||||
}
|
||||
@@ -95,12 +95,12 @@ func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
|
||||
func (t *USPDevice) Close() error {
|
||||
func (t *TunDevice) Close() error {
|
||||
if t.configurer != nil {
|
||||
t.configurer.Close()
|
||||
}
|
||||
@@ -115,39 +115,39 @@ func (t *USPDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) WgAddress() wgaddr.Address {
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
func (t *USPDevice) MTU() uint16 {
|
||||
func (t *TunDevice) MTU() uint16 {
|
||||
return t.mtu
|
||||
}
|
||||
|
||||
func (t *USPDevice) DeviceName() string {
|
||||
func (t *TunDevice) DeviceName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
||||
func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *USPDevice) Device() *device.Device {
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *USPDevice) assignAddr() error {
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
link := newWGLink(t.name)
|
||||
|
||||
return link.assignAddr(t.address)
|
||||
return link.assignAddr(&t.address)
|
||||
}
|
||||
|
||||
func (t *USPDevice) GetNet() *netstack.Net {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *USPDevice) GetICEBind() EndpointManager {
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -87,7 +87,19 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
err = nbiface.Set()
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err)
|
||||
return nil, fmt.Errorf("set IPv4 interface MTU: %s", err)
|
||||
}
|
||||
|
||||
if t.address.HasIPv6() {
|
||||
nbiface6, err := luid.IPInterface(windows.AF_INET6)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get IPv6 interface for MTU: %v", err)
|
||||
} else {
|
||||
nbiface6.NLMTU = uint32(t.mtu)
|
||||
if err := nbiface6.Set(); err != nil {
|
||||
log.Warnf("failed to set IPv6 interface MTU: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
err = t.assignAddr()
|
||||
if err != nil {
|
||||
@@ -178,8 +190,21 @@ func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
|
||||
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
|
||||
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
|
||||
|
||||
v4Prefix := t.address.Prefix()
|
||||
if t.address.HasIPv6() {
|
||||
v6Prefix := t.address.IPv6Prefix()
|
||||
log.Debugf("adding addresses %s, %s to interface: %s", v4Prefix, v6Prefix, t.name)
|
||||
if err := luid.SetIPAddresses([]netip.Prefix{v4Prefix, v6Prefix}); err != nil {
|
||||
log.Warnf("failed to assign dual-stack addresses, retrying v4-only: %v", err)
|
||||
t.address.ClearIPv6()
|
||||
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("adding address %s to interface: %s", v4Prefix, t.name)
|
||||
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
//go:build (!linux && !freebsd) || android
|
||||
|
||||
package device
|
||||
|
||||
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package device
|
||||
|
||||
// WireGuardModuleIsLoaded check if kernel support wireguard
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
|
||||
// we are currently do not use it, since it is required to add wireguard kernel support to
|
||||
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
|
||||
// - https://github.com/mdlayher/socket
|
||||
// TODO: implement kernel space
|
||||
return false
|
||||
}
|
||||
|
||||
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
|
||||
func ModuleTunIsLoaded() bool {
|
||||
// Assume tun supported by freebsd kernel by default
|
||||
// TODO: implement check for module loaded in kernel or build-it
|
||||
return true
|
||||
}
|
||||
13
client/iface/device/kernel_module_nonlinux.go
Normal file
13
client/iface/device/kernel_module_nonlinux.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package device
|
||||
|
||||
// WireGuardModuleIsLoaded reports whether the kernel WireGuard module is available.
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ModuleTunIsLoaded reports whether the tun device is available.
|
||||
func ModuleTunIsLoaded() bool {
|
||||
return true
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -57,32 +58,32 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
||||
link, err := freebsd.LinkByName(l.name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("link by name: %w", err)
|
||||
}
|
||||
|
||||
ip := address.IP.String()
|
||||
|
||||
// Convert prefix length to hex netmask
|
||||
prefixLen := address.Network.Bits()
|
||||
if !address.IP.Is4() {
|
||||
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||
}
|
||||
|
||||
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||
|
||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||
log.Infof("assign addr %s mask %s to %s interface", address.IP, mask, l.name)
|
||||
|
||||
err = link.AssignAddr(ip, mask)
|
||||
if err != nil {
|
||||
if err := link.AssignAddr(address.IP.String(), mask); err != nil {
|
||||
return fmt.Errorf("assign addr: %w", err)
|
||||
}
|
||||
|
||||
err = link.Up()
|
||||
if err != nil {
|
||||
if address.HasIPv6() {
|
||||
log.Infof("assign IPv6 addr %s to %s interface", address.IPv6String(), l.name)
|
||||
cmd := exec.Command("ifconfig", l.name, "inet6", address.IPv6String())
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %s: %v", address.IPv6String(), l.name, string(out), err)
|
||||
address.ClearIPv6()
|
||||
}
|
||||
}
|
||||
|
||||
if err := link.Up(); err != nil {
|
||||
return fmt.Errorf("up: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -92,7 +94,7 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
||||
//delete existing addresses
|
||||
list, err := netlink.AddrList(l, 0)
|
||||
if err != nil {
|
||||
@@ -110,20 +112,16 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
}
|
||||
|
||||
name := l.attrs.Name
|
||||
addrStr := address.String()
|
||||
|
||||
log.Debugf("adding address %s to interface: %s", addrStr, name)
|
||||
|
||||
addr, err := netlink.ParseAddr(addrStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse addr: %w", err)
|
||||
if err := l.addAddr(name, address.Prefix()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = netlink.AddrAdd(l, addr)
|
||||
if os.IsExist(err) {
|
||||
log.Infof("interface %s already has the address: %s", name, addrStr)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("add addr: %w", err)
|
||||
if address.HasIPv6() {
|
||||
if err := l.addAddr(name, address.IPv6Prefix()); err != nil {
|
||||
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %v", address.IPv6Prefix(), name, err)
|
||||
address.ClearIPv6()
|
||||
}
|
||||
}
|
||||
|
||||
// On linux, the link must be brought up
|
||||
@@ -133,3 +131,22 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) addAddr(ifaceName string, prefix netip.Prefix) error {
|
||||
log.Debugf("adding address %s to interface: %s", prefix, ifaceName)
|
||||
|
||||
addr := &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: prefix.Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()),
|
||||
},
|
||||
}
|
||||
|
||||
if err := netlink.AddrAdd(l, addr); os.IsExist(err) {
|
||||
log.Infof("interface %s already has the address: %s", ifaceName, prefix)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("add addr %s: %w", prefix, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ type wgProxyFactory interface {
|
||||
|
||||
type WGIFaceOpts struct {
|
||||
IFaceName string
|
||||
Address string
|
||||
Address wgaddr.Address
|
||||
WGPort int
|
||||
WGPrivKey string
|
||||
MTU uint16
|
||||
@@ -141,16 +141,11 @@ func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
}
|
||||
|
||||
// UpdateAddr updates address of the interface
|
||||
func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||
func (w *WGIface) UpdateAddr(newAddr wgaddr.Address) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
addr, err := wgaddr.ParseWGAddress(newAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return w.tun.UpdateAddr(addr)
|
||||
return w.tun.UpdateAddr(newAddr)
|
||||
}
|
||||
|
||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||
|
||||
@@ -1,33 +1,28 @@
|
||||
//go:build !linux && !ios && !android && !js
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
tun = device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
} else {
|
||||
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
tun = device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{
|
||||
return &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
}, nil
|
||||
}
|
||||
@@ -4,23 +4,17 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
@@ -28,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||
tun: device.NewTunDevice(opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
} else {
|
||||
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
//go:build freebsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
@@ -5,21 +5,15 @@ package iface
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
|
||||
@@ -4,21 +4,15 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
relayBind := bind.NewRelayBindJS()
|
||||
|
||||
wgIface := &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
|
||||
}
|
||||
|
||||
@@ -3,44 +3,40 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if device.WireGuardModuleIsLoaded() {
|
||||
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
|
||||
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
return &WGIface{
|
||||
tun: device.NewKernelDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet),
|
||||
wgProxyFactory: wgproxy.NewKernelFactory(opts.WGPort, opts.MTU),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("tun module not available")
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
@@ -48,7 +49,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: addr,
|
||||
Address: wgaddr.MustParseWGAddress(addr),
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -84,7 +85,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
|
||||
//update WireGuard address
|
||||
addr = "100.64.0.2/8"
|
||||
err = iface.UpdateAddr(addr)
|
||||
err = iface.UpdateAddr(wgaddr.MustParseWGAddress(addr))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -130,7 +131,7 @@ func Test_CreateInterface(t *testing.T) {
|
||||
}
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -174,7 +175,7 @@ func Test_Close(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -219,7 +220,7 @@ func TestRecreation(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -291,7 +292,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
}
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -347,7 +348,7 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -417,7 +418,7 @@ func Test_RemovePeer(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -482,7 +483,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
|
||||
optsPeer1 := WGIFaceOpts{
|
||||
IFaceName: peer1ifaceName,
|
||||
Address: peer1wgIP.String(),
|
||||
Address: wgaddr.MustParseWGAddress(peer1wgIP.String()),
|
||||
WGPort: peer1wgPort,
|
||||
WGPrivKey: peer1Key.String(),
|
||||
MTU: DefaultMTU,
|
||||
@@ -522,7 +523,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
|
||||
optsPeer2 := WGIFaceOpts{
|
||||
IFaceName: peer2ifaceName,
|
||||
Address: peer2wgIP.String(),
|
||||
Address: wgaddr.MustParseWGAddress(peer2wgIP.String()),
|
||||
WGPort: peer2wgPort,
|
||||
WGPrivKey: peer2Key.String(),
|
||||
MTU: DefaultMTU,
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||
|
||||
type NetStackTun struct { //nolint:revive
|
||||
address netip.Addr
|
||||
addresses []netip.Addr
|
||||
dnsAddress netip.Addr
|
||||
mtu int
|
||||
listenAddress string
|
||||
@@ -22,9 +22,9 @@ type NetStackTun struct { //nolint:revive
|
||||
tundev tun.Device
|
||||
}
|
||||
|
||||
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||
func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||
return &NetStackTun{
|
||||
address: address,
|
||||
addresses: addresses,
|
||||
dnsAddress: dnsAddress,
|
||||
mtu: mtu,
|
||||
listenAddress: listenAddress,
|
||||
@@ -33,7 +33,7 @@ func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.A
|
||||
|
||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||
[]netip.Addr{t.address},
|
||||
t.addresses,
|
||||
[]netip.Addr{t.dnsAddress},
|
||||
t.mtu)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,12 +3,18 @@ package wgaddr
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
// Address WireGuard parsed address
|
||||
type Address struct {
|
||||
IP netip.Addr
|
||||
Network netip.Prefix
|
||||
|
||||
// IPv6 overlay address, if assigned.
|
||||
IPv6 netip.Addr
|
||||
IPv6Net netip.Prefix
|
||||
}
|
||||
|
||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
@@ -23,6 +29,60 @@ func ParseWGAddress(address string) (Address, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (addr Address) String() string {
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||
// HasIPv6 reports whether a v6 overlay address is assigned.
|
||||
func (addr Address) HasIPv6() bool {
|
||||
return addr.IPv6.IsValid()
|
||||
}
|
||||
|
||||
func (addr Address) String() string {
|
||||
return addr.Prefix().String()
|
||||
}
|
||||
|
||||
// IPv6String returns the v6 address in CIDR notation, or empty string if none.
|
||||
func (addr Address) IPv6String() string {
|
||||
if !addr.HasIPv6() {
|
||||
return ""
|
||||
}
|
||||
return addr.IPv6Prefix().String()
|
||||
}
|
||||
|
||||
// Prefix returns the v4 host address with its network prefix length (e.g. 100.64.0.1/16).
|
||||
func (addr Address) Prefix() netip.Prefix {
|
||||
return netip.PrefixFrom(addr.IP, addr.Network.Bits())
|
||||
}
|
||||
|
||||
// IPv6Prefix returns the v6 host address with its network prefix length, or a zero prefix if none.
|
||||
func (addr Address) IPv6Prefix() netip.Prefix {
|
||||
if !addr.HasIPv6() {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
return netip.PrefixFrom(addr.IPv6, addr.IPv6Net.Bits())
|
||||
}
|
||||
|
||||
// SetIPv6FromCompact decodes a compact prefix (5 or 17 bytes) and sets the IPv6 fields.
|
||||
// Returns an error if the bytes are invalid. A nil or empty input is a no-op.
|
||||
//
|
||||
//nolint:recvcheck
|
||||
func (addr *Address) SetIPv6FromCompact(raw []byte) error {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
prefix, err := netiputil.DecodePrefix(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode v6 overlay address: %w", err)
|
||||
}
|
||||
if !prefix.Addr().Is6() {
|
||||
return fmt.Errorf("expected IPv6 address, got %s", prefix.Addr())
|
||||
}
|
||||
addr.IPv6 = prefix.Addr()
|
||||
addr.IPv6Net = prefix.Masked()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearIPv6 removes the IPv6 overlay address, leaving only v4.
|
||||
//
|
||||
//nolint:recvcheck
|
||||
func (addr *Address) ClearIPv6() {
|
||||
addr.IPv6 = netip.Addr{}
|
||||
addr.IPv6Net = netip.Prefix{}
|
||||
}
|
||||
|
||||
10
client/iface/wgaddr/address_test_helpers.go
Normal file
10
client/iface/wgaddr/address_test_helpers.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package wgaddr
|
||||
|
||||
// MustParseWGAddress parses and returns a WG Address, panicking on error.
|
||||
func MustParseWGAddress(address string) Address {
|
||||
a, err := ParseWGAddress(address)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -196,18 +196,22 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
// fakeAddress returns a fake address that is used as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is derived from the
|
||||
// last two bytes of the peer address (works for both IPv4 and IPv6).
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
if peerAddress == nil {
|
||||
return nil, fmt.Errorf("nil peer address")
|
||||
}
|
||||
|
||||
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse new IP: %w", err)
|
||||
addr, ok := netip.AddrFromSlice(peerAddress.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
raw := addr.As16()
|
||||
fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]})
|
||||
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||
@@ -105,6 +105,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
|
||||
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
||||
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
||||
// the missing deny. Currently we accumulate errors and continue.
|
||||
var merr *multierror.Error
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
@@ -117,9 +121,8 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
}
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
d.rollBack(newRulePairs)
|
||||
break
|
||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||
continue
|
||||
}
|
||||
if len(rulePair) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
@@ -127,6 +130,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
}
|
||||
}
|
||||
|
||||
if merr != nil {
|
||||
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
@@ -216,9 +223,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (id.RuleID, []firewall.Rule, error) {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
ip, err := extractRuleIP(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||
@@ -289,13 +296,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -305,7 +312,7 @@ func (d *DefaultManager) addInRules(
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
@@ -315,7 +322,7 @@ func (d *DefaultManager) addOutRules(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -323,9 +330,9 @@ func (d *DefaultManager) addOutRules(
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getPeerRuleID(
|
||||
ip net.IP,
|
||||
ip netip.Addr,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
@@ -344,15 +351,25 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
log.Debugf("rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
|
||||
}
|
||||
|
||||
// extractRuleIP extracts the peer IP from a firewall rule.
|
||||
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
||||
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
||||
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
||||
if len(r.SourcePrefixes) > 0 {
|
||||
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||
addr, err := netip.ParseAddr(r.PeerIP)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
|
||||
@@ -345,6 +345,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.DisableFirewall,
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.DisableIPv6,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
|
||||
@@ -14,10 +14,13 @@ import (
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
@@ -520,9 +523,20 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
}
|
||||
wgAddr, err := wgaddr.ParseWGAddress(peerConfig.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse overlay address %q: %w", peerConfig.Address, err)
|
||||
}
|
||||
|
||||
if !config.DisableIPv6 {
|
||||
if err := wgAddr.SetIPv6FromCompact(peerConfig.GetAddressV6()); err != nil {
|
||||
log.Warn(err)
|
||||
}
|
||||
}
|
||||
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
WgAddr: wgAddr,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
@@ -547,6 +561,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
DisableFirewall: config.DisableFirewall,
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
BlockInbound: config.BlockInbound,
|
||||
DisableIPv6: config.DisableIPv6,
|
||||
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
|
||||
@@ -627,6 +642,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.DisableIPv6,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
|
||||
@@ -522,6 +522,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
||||
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
||||
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
|
||||
configContent.WriteString(fmt.Sprintf("DisableIPv6: %v\n", g.internalConfig.DisableIPv6))
|
||||
|
||||
if g.internalConfig.DisableNotifications != nil {
|
||||
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
|
||||
@@ -1270,8 +1271,9 @@ func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.An
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility
|
||||
if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
|
||||
rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
|
||||
rule.PeerIP = anonymizer.AnonymizeIP(addr).String() //nolint:staticcheck
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -430,8 +430,6 @@ func isInCGNATRange(ip net.IP) bool {
|
||||
}
|
||||
|
||||
func TestAnonymizeFirewallRules(t *testing.T) {
|
||||
// TODO: Add ipv6
|
||||
|
||||
// Example iptables-save output
|
||||
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
||||
*filter
|
||||
@@ -467,17 +465,31 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
|
||||
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||
pkts bytes target prot opt in out source destination`
|
||||
|
||||
// Example nftables output
|
||||
// Example ip6tables-save output
|
||||
ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
||||
*filter
|
||||
:INPUT ACCEPT [0:0]
|
||||
:FORWARD ACCEPT [0:0]
|
||||
:OUTPUT ACCEPT [0:0]
|
||||
-A INPUT -s fd00:1234::1/128 -j ACCEPT
|
||||
-A INPUT -s 2607:f8b0:4005::1/128 -j DROP
|
||||
-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT
|
||||
COMMIT`
|
||||
|
||||
// Example nftables output with IPv6
|
||||
nftablesRules := `table inet filter {
|
||||
chain input {
|
||||
type filter hook input priority filter; policy accept;
|
||||
ip saddr 192.168.1.1 accept
|
||||
ip saddr 44.192.140.1 drop
|
||||
ip6 saddr 2607:f8b0:4005::1 drop
|
||||
ip6 saddr fd00:1234::1 accept
|
||||
}
|
||||
chain forward {
|
||||
type filter hook forward priority filter; policy accept;
|
||||
ip saddr 10.0.0.0/8 drop
|
||||
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
|
||||
ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept
|
||||
}
|
||||
}`
|
||||
|
||||
@@ -540,4 +552,35 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||
assert.Contains(t, anonNftables, "table inet filter {")
|
||||
assert.Contains(t, anonNftables, "chain input {")
|
||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||
|
||||
// IPv6 public addresses in nftables should be anonymized
|
||||
assert.NotContains(t, anonNftables, "2607:f8b0:4005::1")
|
||||
assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e")
|
||||
assert.NotContains(t, anonNftables, "2001:db8::")
|
||||
assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range
|
||||
|
||||
// ULA addresses in nftables should remain unchanged (private)
|
||||
assert.Contains(t, anonNftables, "fd00:1234::1")
|
||||
|
||||
// IPv6 nftables structure preserved
|
||||
assert.Contains(t, anonNftables, "ip6 saddr")
|
||||
assert.Contains(t, anonNftables, "ip6 daddr")
|
||||
|
||||
// Test ip6tables-save anonymization
|
||||
anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave)
|
||||
|
||||
// ULA (private) IPv6 should remain unchanged
|
||||
assert.Contains(t, anonIp6tablesSave, "fd00:1234::1/128")
|
||||
|
||||
// Public IPv6 addresses should be anonymized
|
||||
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1")
|
||||
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e")
|
||||
assert.NotContains(t, anonIp6tablesSave, "2001:db8::")
|
||||
assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range
|
||||
|
||||
// Structure should be preserved
|
||||
assert.Contains(t, anonIp6tablesSave, "*filter")
|
||||
assert.Contains(t, anonIp6tablesSave, "COMMIT")
|
||||
assert.Contains(t, anonIp6tablesSave, "-j DROP")
|
||||
assert.Contains(t, anonIp6tablesSave, "-j ACCEPT")
|
||||
}
|
||||
|
||||
@@ -12,52 +12,83 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
||||
ip, err := netip.ParseAddr(aRecord.RData)
|
||||
func createPTRRecord(record nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
||||
ip, err := netip.ParseAddr(record.RData)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
|
||||
log.Warnf("failed to parse IP address %s: %v", record.RData, err)
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
ip = ip.Unmap()
|
||||
if !prefix.Contains(ip) {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
ipOctets := strings.Split(ip.String(), ".")
|
||||
slices.Reverse(ipOctets)
|
||||
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
|
||||
var rdnsName string
|
||||
if ip.Is4() {
|
||||
octets := strings.Split(ip.String(), ".")
|
||||
slices.Reverse(octets)
|
||||
rdnsName = dns.Fqdn(strings.Join(octets, ".") + ".in-addr.arpa")
|
||||
} else {
|
||||
// Expand to full 32 nibbles in reverse order (LSB first) per RFC 3596.
|
||||
raw := ip.As16()
|
||||
nibbles := make([]string, 32)
|
||||
for i := 0; i < 16; i++ {
|
||||
nibbles[31-i*2] = fmt.Sprintf("%x", raw[i]>>4)
|
||||
nibbles[31-i*2-1] = fmt.Sprintf("%x", raw[i]&0x0f)
|
||||
}
|
||||
rdnsName = dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
|
||||
}
|
||||
|
||||
return nbdns.SimpleRecord{
|
||||
Name: rdnsName,
|
||||
Type: int(dns.TypePTR),
|
||||
Class: aRecord.Class,
|
||||
TTL: aRecord.TTL,
|
||||
RData: dns.Fqdn(aRecord.Name),
|
||||
Class: record.Class,
|
||||
TTL: record.TTL,
|
||||
RData: dns.Fqdn(record.Name),
|
||||
}, true
|
||||
}
|
||||
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network.
|
||||
// For IPv4 it produces an in-addr.arpa name, for IPv6 an ip6.arpa name.
|
||||
func generateReverseZoneName(network netip.Prefix) (string, error) {
|
||||
networkIP := network.Masked().Addr()
|
||||
networkIP := network.Masked().Addr().Unmap()
|
||||
bits := network.Bits()
|
||||
|
||||
if !networkIP.Is4() {
|
||||
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
|
||||
if networkIP.Is4() {
|
||||
// Round up to nearest byte.
|
||||
octetsToUse := (bits + 7) / 8
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", bits)
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
for i := 0; i < octetsToUse; i++ {
|
||||
reverseOctets[octetsToUse-1-i] = octets[i]
|
||||
}
|
||||
|
||||
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
||||
}
|
||||
|
||||
// round up to nearest byte
|
||||
octetsToUse := (network.Bits() + 7) / 8
|
||||
// IPv6: round up to nearest nibble (4-bit boundary).
|
||||
nibblesToUse := (bits + 3) / 4
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
|
||||
raw := networkIP.As16()
|
||||
allNibbles := make([]string, 32)
|
||||
for i := 0; i < 16; i++ {
|
||||
allNibbles[i*2] = fmt.Sprintf("%x", raw[i]>>4)
|
||||
allNibbles[i*2+1] = fmt.Sprintf("%x", raw[i]&0x0f)
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
for i := 0; i < octetsToUse; i++ {
|
||||
reverseOctets[octetsToUse-1-i] = octets[i]
|
||||
// Take the first nibblesToUse nibbles (network portion), reverse them.
|
||||
used := make([]string, nibblesToUse)
|
||||
for i := 0; i < nibblesToUse; i++ {
|
||||
used[nibblesToUse-1-i] = allNibbles[i]
|
||||
}
|
||||
|
||||
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
||||
return dns.Fqdn(strings.Join(used, ".") + ".ip6.arpa"), nil
|
||||
}
|
||||
|
||||
// zoneExists checks if a zone with the given name already exists in the configuration
|
||||
@@ -71,7 +102,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||
// collectPTRRecords gathers all PTR records for the given network from A and AAAA records.
|
||||
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
@@ -80,7 +111,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
||||
continue
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
if record.Type != int(dns.TypeA) {
|
||||
if record.Type != int(dns.TypeA) && record.Type != int(dns.TypeAAAA) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -298,6 +298,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||
ip = ip.Unmap()
|
||||
serverAddresses = append(serverAddresses, ip)
|
||||
// Prefer the first IPv4 server as ServerIP since our DNS listener is IPv4.
|
||||
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip
|
||||
}
|
||||
|
||||
@@ -110,8 +110,15 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
||||
|
||||
connSettings.cleanDeprecatedSettings()
|
||||
|
||||
convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice())
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
|
||||
ipKey := networkManagerDbusIPv4Key
|
||||
if config.ServerIP.Is6() {
|
||||
ipKey = networkManagerDbusIPv6Key
|
||||
raw := config.ServerIP.As16()
|
||||
connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([][]byte{raw[:]})
|
||||
} else {
|
||||
convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice())
|
||||
connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
|
||||
}
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
@@ -146,8 +153,8 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
||||
n.routingAll = false
|
||||
}
|
||||
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
||||
connSettings[ipKey][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
||||
connSettings[ipKey][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
||||
|
||||
state := &ShutdownState{
|
||||
ManagerType: networkManager,
|
||||
|
||||
@@ -347,7 +347,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
|
||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
@@ -448,7 +448,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: "100.66.100.1/32",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
@@ -929,7 +929,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: "100.66.100.2/24",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.2/24"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
|
||||
@@ -140,10 +140,10 @@ func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// evalListenAddress figure out the listen address for the DNS server
|
||||
// first check the 53 port availability on WG interface or lo, if not success
|
||||
// pick a random port on WG interface for eBPF, if not success
|
||||
// check the 5053 port availability on WG interface or lo without eBPF usage,
|
||||
// evalListenAddress figures out the listen address for the DNS server.
|
||||
// IPv4-only: all peers have a v4 overlay address, and DNS config points to v4.
|
||||
// First checks port 53 on WG interface or lo, then tries eBPF on a random port,
|
||||
// then falls back to port 5053.
|
||||
func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
|
||||
if s.customAddr != nil {
|
||||
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
||||
@@ -219,7 +219,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) {
|
||||
}
|
||||
|
||||
ebpfSrv := ebpf.GetEbpfManagerInstance()
|
||||
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port))
|
||||
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP, int(port))
|
||||
if err != nil {
|
||||
log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err)
|
||||
return nil, 0, false
|
||||
|
||||
@@ -90,8 +90,12 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
family := int32(unix.AF_INET)
|
||||
if config.ServerIP.Is6() {
|
||||
family = unix.AF_INET6
|
||||
}
|
||||
defaultLinkInput := systemdDbusDNSInput{
|
||||
Family: unix.AF_INET,
|
||||
Family: family,
|
||||
Address: config.ServerIP.AsSlice(),
|
||||
}
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -29,6 +30,12 @@ import (
|
||||
|
||||
var currentMTU uint16 = iface.DefaultMTU
|
||||
|
||||
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
|
||||
type privateClientIface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
func SetCurrentMTU(mtu uint16) {
|
||||
currentMTU = mtu
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
return &dns.Client{
|
||||
Timeout: dialTimeout,
|
||||
Net: "udp",
|
||||
|
||||
@@ -52,7 +52,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns
|
||||
return ExchangeWithFallback(ctx, client, r, upstream)
|
||||
}
|
||||
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
return &dns.Client{
|
||||
Timeout: dialTimeout,
|
||||
Net: "udp",
|
||||
|
||||
@@ -19,9 +19,7 @@ import (
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
*upstreamResolverBase
|
||||
lIP netip.Addr
|
||||
lNet netip.Prefix
|
||||
interfaceName string
|
||||
wgIface WGIface
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
@@ -35,9 +33,7 @@ func newUpstreamResolver(
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
lIP: wgIface.Address().IP,
|
||||
lNet: wgIface.Address().Network,
|
||||
interfaceName: wgIface.Name(),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
ios.upstreamClient = ios
|
||||
|
||||
@@ -65,11 +61,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
} else {
|
||||
upstreamIP = upstreamIP.Unmap()
|
||||
}
|
||||
needsPrivate := u.lNet.Contains(upstreamIP) ||
|
||||
addr := u.wgIface.Address()
|
||||
needsPrivate := addr.Network.Contains(upstreamIP) ||
|
||||
addr.IPv6Net.Contains(upstreamIP) ||
|
||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||
if needsPrivate {
|
||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create private client: %s", err)
|
||||
}
|
||||
@@ -79,25 +77,33 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
return ExchangeWithFallback(nil, client, r, upstream)
|
||||
}
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// This method is needed for iOS
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
|
||||
// It selects the v6 bind address when the upstream is IPv6 and the interface has one, otherwise v4.
|
||||
func GetClientPrivate(iface privateClientIface, upstreamIP netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(iface.Name())
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
log.Debugf("unable to get interface index for %s: %s", iface.Name(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr := iface.Address()
|
||||
bindIP := addr.IP
|
||||
if upstreamIP.Is6() && addr.HasIPv6() {
|
||||
bindIP = addr.IPv6
|
||||
}
|
||||
|
||||
proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF
|
||||
if bindIP.Is6() {
|
||||
proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
LocalAddr: &net.UDPAddr{
|
||||
IP: ip.AsSlice(),
|
||||
Port: 0, // Let the OS pick a free port
|
||||
},
|
||||
Timeout: dialTimeout,
|
||||
LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(bindIP, 0)),
|
||||
Timeout: dialTimeout,
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var operr error
|
||||
fn := func(s uintptr) {
|
||||
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
|
||||
operr = unix.SetsockoptInt(int(s), proto, opt, index)
|
||||
}
|
||||
|
||||
if err := c.Control(fn); err != nil {
|
||||
|
||||
138
client/internal/dns_test.go
Normal file
138
client/internal/dns_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func TestCreatePTRRecord_IPv4(t *testing.T) {
|
||||
record := nbdns.SimpleRecord{
|
||||
Name: "peer1.netbird.cloud.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "100.64.0.5",
|
||||
}
|
||||
prefix := netip.MustParsePrefix("100.64.0.0/16")
|
||||
|
||||
ptr, ok := createPTRRecord(record, prefix)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "5.0.64.100.in-addr.arpa.", ptr.Name)
|
||||
assert.Equal(t, int(dns.TypePTR), ptr.Type)
|
||||
assert.Equal(t, "peer1.netbird.cloud.", ptr.RData)
|
||||
}
|
||||
|
||||
func TestCreatePTRRecord_IPv6(t *testing.T) {
|
||||
record := nbdns.SimpleRecord{
|
||||
Name: "peer1.netbird.cloud.",
|
||||
Type: int(dns.TypeAAAA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "fd00:1234:5678::1",
|
||||
}
|
||||
prefix := netip.MustParsePrefix("fd00:1234:5678::/48")
|
||||
|
||||
ptr, ok := createPTRRecord(record, prefix)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", ptr.Name)
|
||||
assert.Equal(t, int(dns.TypePTR), ptr.Type)
|
||||
assert.Equal(t, "peer1.netbird.cloud.", ptr.RData)
|
||||
}
|
||||
|
||||
func TestCreatePTRRecord_OutOfRange(t *testing.T) {
|
||||
record := nbdns.SimpleRecord{
|
||||
Name: "peer1.netbird.cloud.",
|
||||
Type: int(dns.TypeA),
|
||||
RData: "10.0.0.1",
|
||||
}
|
||||
prefix := netip.MustParsePrefix("100.64.0.0/16")
|
||||
|
||||
_, ok := createPTRRecord(record, prefix)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestGenerateReverseZoneName_IPv4(t *testing.T) {
|
||||
tests := []struct {
|
||||
prefix string
|
||||
expected string
|
||||
}{
|
||||
{"100.64.0.0/16", "64.100.in-addr.arpa."},
|
||||
{"10.0.0.0/8", "10.in-addr.arpa."},
|
||||
{"192.168.1.0/24", "1.168.192.in-addr.arpa."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.prefix, func(t *testing.T) {
|
||||
zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, zone)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateReverseZoneName_IPv6(t *testing.T) {
|
||||
tests := []struct {
|
||||
prefix string
|
||||
expected string
|
||||
}{
|
||||
{"fd00:1234:5678::/48", "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa."},
|
||||
{"fd00::/16", "0.0.d.f.ip6.arpa."},
|
||||
{"fd12:3456:789a:bcde::/64", "e.d.c.b.a.9.8.7.6.5.4.3.2.1.d.f.ip6.arpa."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.prefix, func(t *testing.T) {
|
||||
zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, zone)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectPTRRecords_BothFamilies(t *testing.T) {
|
||||
config := &nbdns.Config{
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.1"},
|
||||
{Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00::1"},
|
||||
{Name: "peer2.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
v4Records := collectPTRRecords(config, netip.MustParsePrefix("100.64.0.0/16"))
|
||||
assert.Len(t, v4Records, 2, "should collect 2 A record PTRs for the v4 prefix")
|
||||
|
||||
v6Records := collectPTRRecords(config, netip.MustParsePrefix("fd00::/64"))
|
||||
assert.Len(t, v6Records, 1, "should collect 1 AAAA record PTR for the v6 prefix")
|
||||
}
|
||||
|
||||
func TestAddReverseZone_IPv6(t *testing.T) {
|
||||
config := &nbdns.Config{
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00:1234:5678::1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
addReverseZone(config, netip.MustParsePrefix("fd00:1234:5678::/48"))
|
||||
|
||||
require.Len(t, config.CustomZones, 2)
|
||||
reverseZone := config.CustomZones[1]
|
||||
assert.Equal(t, "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", reverseZone.Domain)
|
||||
assert.Len(t, reverseZone.Records, 1)
|
||||
assert.Equal(t, int(dns.TypePTR), reverseZone.Records[0].Type)
|
||||
}
|
||||
@@ -80,6 +80,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// IPv4-only: peers reach the forwarder via its v4 overlay address.
|
||||
localAddr := m.wgIface.Address().IP
|
||||
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
|
||||
@@ -2,7 +2,8 @@ package ebpf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -12,7 +13,7 @@ const (
|
||||
mapKeyDNSPort uint32 = 1
|
||||
)
|
||||
|
||||
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
||||
func (tf *GeneralManager) LoadDNSFwd(ip netip.Addr, dnsPort int) error {
|
||||
log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort)
|
||||
tf.lock.Lock()
|
||||
defer tf.lock.Unlock()
|
||||
@@ -22,7 +23,11 @@ func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
|
||||
if !ip.Is4() {
|
||||
return fmt.Errorf("eBPF DNS forwarder only supports IPv4, got %s", ip)
|
||||
}
|
||||
ip4 := ip.As4()
|
||||
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, binary.BigEndian.Uint32(ip4[:]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -45,7 +50,3 @@ func (tf *GeneralManager) FreeDNSFwd() error {
|
||||
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
|
||||
}
|
||||
|
||||
func ip2int(ipString string) uint32 {
|
||||
ip := net.ParseIP(ipString)
|
||||
return binary.BigEndian.Uint32(ip.To4())
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package manager
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
|
||||
type Manager interface {
|
||||
LoadDNSFwd(ip string, dnsPort int) error
|
||||
LoadDNSFwd(ip netip.Addr, dnsPort int) error
|
||||
FreeDNSFwd() error
|
||||
LoadWgProxy(proxyPort, wgPort int) error
|
||||
FreeWGProxy() error
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
@@ -61,6 +62,7 @@ import (
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
@@ -84,8 +86,9 @@ type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
|
||||
// WgAddr is a Wireguard local address (Netbird Network IP)
|
||||
WgAddr string
|
||||
// WgAddr is the Wireguard local address (Netbird Network IP).
|
||||
// Contains both v4 and optional v6 overlay addresses.
|
||||
WgAddr wgaddr.Address
|
||||
|
||||
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
|
||||
WgPrivateKey wgtypes.Key
|
||||
@@ -130,6 +133,7 @@ type EngineConfig struct {
|
||||
DisableFirewall bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
DisableIPv6 bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
@@ -621,7 +625,7 @@ func (e *Engine) initFirewall() error {
|
||||
rosenpassPort := e.rpManager.GetAddress().Port
|
||||
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||
|
||||
// this rule is static and will be torn down on engine down by the firewall manager
|
||||
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
|
||||
if _, err := e.firewall.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
@@ -673,10 +677,15 @@ func (e *Engine) blockLanAccess() {
|
||||
|
||||
log.Infof("blocking route LAN access for networks: %v", toBlock)
|
||||
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
v6 := netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
for _, network := range toBlock {
|
||||
source := v4
|
||||
if network.Addr().Is6() {
|
||||
source = v6
|
||||
}
|
||||
if _, err := e.firewall.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{v4},
|
||||
[]netip.Prefix{source},
|
||||
firewallManager.Network{Prefix: network},
|
||||
firewallManager.ProtocolALL,
|
||||
nil,
|
||||
@@ -714,7 +723,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if !compareNetIPLists(allowedIPs, p.GetAllowedIps()) {
|
||||
if !compareNetIPLists(allowedIPs, e.filterAllowedIPs(p.GetAllowedIps())) {
|
||||
modified = append(modified, p)
|
||||
continue
|
||||
}
|
||||
@@ -988,6 +997,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.DisableFirewall,
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
@@ -1015,6 +1025,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
return ErrResetConnection
|
||||
}
|
||||
|
||||
if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) {
|
||||
log.Infof("peer IPv6 address changed, restarting client")
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return ErrResetConnection
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
if err := e.updateSSH(conf.GetSshConfig()); err != nil {
|
||||
log.Warnf("failed handling SSH server setup: %v", err)
|
||||
@@ -1023,6 +1040,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.IPv6 = e.wgInterface.Address().IPv6String()
|
||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||
state.FQDN = conf.GetFqdn()
|
||||
@@ -1031,6 +1049,28 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasIPv6Changed reports whether the IPv6 overlay address in the peer config
|
||||
// differs from the configured address (added, removed, or changed).
|
||||
// Compares against e.config.WgAddr (not the interface address, which may have
|
||||
// been cleared by ClearIPv6 if OS assignment failed).
|
||||
func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
|
||||
current := e.config.WgAddr
|
||||
raw := conf.GetAddressV6()
|
||||
|
||||
if len(raw) == 0 {
|
||||
return current.HasIPv6()
|
||||
}
|
||||
|
||||
prefix, err := netiputil.DecodePrefix(raw)
|
||||
if err != nil {
|
||||
log.Warnf("decode v6 overlay address: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
|
||||
}
|
||||
|
||||
func (e *Engine) receiveJobEvents() {
|
||||
e.jobExecutorWG.Add(1)
|
||||
go func() {
|
||||
@@ -1128,6 +1168,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.DisableFirewall,
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
@@ -1227,7 +1268,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)
|
||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
@@ -1382,7 +1423,9 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
||||
return entries
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, addr wgaddr.Address) nbdns.Config {
|
||||
network := addr.Network
|
||||
networkV6 := addr.IPv6Net
|
||||
//nolint
|
||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
||||
if forwarderPort == 0 {
|
||||
@@ -1439,6 +1482,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
||||
|
||||
if len(dnsUpdate.CustomZones) > 0 {
|
||||
addReverseZone(&dnsUpdate, network)
|
||||
if networkV6.IsValid() {
|
||||
addReverseZone(&dnsUpdate, networkV6)
|
||||
}
|
||||
}
|
||||
|
||||
return dnsUpdate
|
||||
@@ -1448,8 +1494,10 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||
replacement := make([]peer.State, len(offlinePeers))
|
||||
for i, offlinePeer := range offlinePeers {
|
||||
log.Debugf("added offline peer %s", offlinePeer.Fqdn)
|
||||
v4, v6 := overlayAddrsFromAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||
replacement[i] = peer.State{
|
||||
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
|
||||
IP: addrToString(v4),
|
||||
IPv6: addrToString(v6),
|
||||
PubKey: offlinePeer.GetWgPubKey(),
|
||||
FQDN: offlinePeer.GetFqdn(),
|
||||
ConnStatus: peer.StatusIdle,
|
||||
@@ -1460,6 +1508,37 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||
}
|
||||
|
||||
// overlayAddrsFromAllowedIPs extracts the peer's v4 and v6 overlay addresses
|
||||
// from AllowedIPs strings. Only host routes (/32, /128) are considered; v6 must
|
||||
// fall within ourV6Net to distinguish overlay addresses from routed prefixes.
|
||||
func overlayAddrsFromAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 netip.Addr) {
|
||||
for _, cidr := range allowedIPs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse AllowedIP %q: %v", cidr, err)
|
||||
continue
|
||||
}
|
||||
addr := prefix.Addr().Unmap()
|
||||
switch {
|
||||
case addr.Is4() && prefix.Bits() == 32 && !v4.IsValid():
|
||||
v4 = addr
|
||||
case addr.Is6() && prefix.Bits() == 128 && ourV6Net.Contains(addr) && !v6.IsValid():
|
||||
v6 = addr
|
||||
}
|
||||
if v4.IsValid() && v6.IsValid() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addrToString(addr netip.Addr) string {
|
||||
if !addr.IsValid() {
|
||||
return ""
|
||||
}
|
||||
return addr.String()
|
||||
}
|
||||
|
||||
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
for _, p := range peersUpdate {
|
||||
@@ -1485,15 +1564,23 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
log.Errorf("failed to parse allowedIPS: %v", err)
|
||||
return err
|
||||
}
|
||||
if allowedNetIP.Addr().Is6() && !e.wgInterface.Address().HasIPv6() {
|
||||
continue
|
||||
}
|
||||
peerIPs = append(peerIPs, allowedNetIP)
|
||||
}
|
||||
|
||||
if len(peerIPs) == 0 {
|
||||
return fmt.Errorf("peer %s has no usable AllowedIPs", peerKey)
|
||||
}
|
||||
|
||||
conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String())
|
||||
peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, addrToString(peerV4), addrToString(peerV6))
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
@@ -1716,6 +1803,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.DisableFirewall,
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
@@ -1729,7 +1817,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
return nil, nil, false, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address())
|
||||
dnsFeatureFlag := toDNSFeatureFlag(netMap)
|
||||
return routes, &dnsCfg, dnsFeatureFlag, nil
|
||||
}
|
||||
@@ -1771,7 +1859,8 @@ func (e *Engine) wgInterfaceCreate() (err error) {
|
||||
case "android":
|
||||
err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP().String(), e.dnsServer.SearchDomains())
|
||||
case "ios":
|
||||
e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr)
|
||||
e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr.String())
|
||||
e.mobileDep.NetworkChangeListener.SetInterfaceIPv6(e.config.WgAddr.IPv6String())
|
||||
err = e.wgInterface.Create()
|
||||
default:
|
||||
err = e.wgInterface.Create()
|
||||
@@ -2268,8 +2357,7 @@ func getInterfacePrefixes() ([]netip.Prefix, error) {
|
||||
prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked()
|
||||
ip := prefix.Addr()
|
||||
|
||||
// TODO: add IPv6
|
||||
if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -2280,6 +2368,24 @@ func getInterfacePrefixes() ([]netip.Prefix, error) {
|
||||
return prefixes, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// filterAllowedIPs strips IPv6 entries when the local interface has no v6 address.
|
||||
// This covers both the explicit --disable-ipv6 flag (v6 never assigned) and the
|
||||
// case where OS v6 assignment failed (ClearIPv6). Without this, WireGuard would
|
||||
// accept v6 traffic that the native firewall cannot filter.
|
||||
func (e *Engine) filterAllowedIPs(ips []string) []string {
|
||||
if e.wgInterface.Address().HasIPv6() {
|
||||
return ips
|
||||
}
|
||||
filtered := make([]string, 0, len(ips))
|
||||
for _, s := range ips {
|
||||
p, err := netip.ParsePrefix(s)
|
||||
if err != nil || !p.Addr().Is6() {
|
||||
filtered = append(filtered, s)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// compareNetIPLists compares a list of netip.Prefix with a list of strings.
|
||||
// return true if both lists are equal, false otherwise.
|
||||
func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool {
|
||||
|
||||
@@ -41,6 +41,14 @@ func (e *Engine) setupSSHPortRedirection() error {
|
||||
}
|
||||
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
if v6 := e.wgInterface.Address().IPv6; v6.IsValid() {
|
||||
if err := e.firewall.AddInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
log.Warnf("failed to add IPv6 SSH port redirection: %v", err)
|
||||
} else {
|
||||
log.Infof("SSH port redirection enabled: [%s]:22 -> [%s]:22022", v6, v6)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -137,12 +145,13 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
|
||||
continue
|
||||
}
|
||||
|
||||
peerIP := e.extractPeerIP(peerConfig)
|
||||
peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||
hostname := e.extractHostname(peerConfig)
|
||||
|
||||
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||
Hostname: hostname,
|
||||
IP: peerIP,
|
||||
IP: peerV4,
|
||||
IPv6: peerV6,
|
||||
FQDN: peerConfig.GetFqdn(),
|
||||
})
|
||||
}
|
||||
@@ -150,18 +159,6 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
|
||||
return peerInfo
|
||||
}
|
||||
|
||||
// extractPeerIP extracts IP address from peer's allowed IPs
|
||||
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
if len(peerConfig.GetAllowedIps()) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
|
||||
return prefix.Addr().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHostname extracts short hostname from FQDN
|
||||
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
fqdn := peerConfig.GetFqdn()
|
||||
@@ -208,7 +205,7 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress || peerState.IPv6 == peerAddress {
|
||||
if len(peerState.SSHHostKey) > 0 {
|
||||
return peerState.SSHHostKey, true
|
||||
}
|
||||
@@ -262,6 +259,13 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||
return fmt.Errorf("start SSH server: %w", err)
|
||||
}
|
||||
|
||||
if v6 := wgAddr.IPv6; v6.IsValid() {
|
||||
v6Addr := netip.AddrPortFrom(v6, sshserver.InternalSSHPort)
|
||||
if err := server.AddListener(e.ctx, v6Addr); err != nil {
|
||||
log.Warnf("failed to add IPv6 SSH listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
e.sshServer = server
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
@@ -330,6 +334,12 @@ func (e *Engine) cleanupSSHPortRedirection() error {
|
||||
}
|
||||
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
if v6 := e.wgInterface.Address().IPv6; v6.IsValid() {
|
||||
if err := e.firewall.RemoveInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
log.Debugf("failed to remove IPv6 SSH port redirection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ import (
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
@@ -94,7 +95,7 @@ type MockWGIface struct {
|
||||
AddressFunc func() wgaddr.Address
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdateAddrFunc func(newAddr wgaddr.Address) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
AddAllowedIPFunc func(peerKey string, allowedIP netip.Prefix) error
|
||||
@@ -156,7 +157,7 @@ func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return m.UpFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdateAddr(newAddr string) error {
|
||||
func (m *MockWGIface) UpdateAddr(newAddr wgaddr.Address) error {
|
||||
return m.UpdateAddrFunc(newAddr)
|
||||
}
|
||||
|
||||
@@ -253,7 +254,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
ctx, cancel,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
@@ -430,7 +431,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
@@ -654,7 +655,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
@@ -824,7 +825,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgAddr: wgaddr.MustParseWGAddress(wgAddr),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
@@ -842,7 +843,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: wgIfaceName,
|
||||
Address: wgAddr,
|
||||
Address: wgaddr.MustParseWGAddress(wgAddr),
|
||||
WGPort: engine.config.WgPort,
|
||||
WGPrivKey: key.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
@@ -1031,7 +1032,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgAddr: wgaddr.MustParseWGAddress(wgAddr),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
@@ -1049,7 +1050,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: wgIfaceName,
|
||||
Address: wgAddr,
|
||||
Address: wgaddr.MustParseWGAddress(wgAddr),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
@@ -1559,7 +1560,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
wgPort := 33100 + i
|
||||
conf := &EngineConfig{
|
||||
WgIfaceName: ifaceName,
|
||||
WgAddr: resp.PeerConfig.Address,
|
||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||
WgPrivateKey: key,
|
||||
WgPort: wgPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
@@ -1704,3 +1705,217 @@ func getPeers(e *Engine) int {
|
||||
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
|
||||
func TestEngine_hasIPv6Changed(t *testing.T) {
|
||||
v4Only := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
|
||||
v4v6 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
v4v6.IPv6 = netip.MustParseAddr("fd00::1")
|
||||
v4v6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
current wgaddr.Address
|
||||
confV6 []byte
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "no v6 before, no v6 now",
|
||||
current: v4Only,
|
||||
confV6: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no v6 before, v6 added",
|
||||
current: v4Only,
|
||||
confV6: netiputil.EncodePrefix(netip.MustParsePrefix("fd00::1/64")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "had v6, now removed",
|
||||
current: v4v6,
|
||||
confV6: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "had v6, same v6",
|
||||
current: v4v6,
|
||||
confV6: netiputil.EncodePrefix(netip.MustParsePrefix("fd00::1/64")),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "had v6, different v6",
|
||||
current: v4v6,
|
||||
confV6: netiputil.EncodePrefix(netip.MustParsePrefix("fd00::2/64")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "same v6 addr, different prefix length",
|
||||
current: v4v6,
|
||||
confV6: netiputil.EncodePrefix(netip.MustParsePrefix("fd00::1/80")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "decode error keeps status quo",
|
||||
current: v4Only,
|
||||
confV6: []byte{1, 2, 3},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{WgAddr: tt.current},
|
||||
}
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
AddressV6: tt.confV6,
|
||||
}
|
||||
assert.Equal(t, tt.expected, engine.hasIPv6Changed(conf))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterAllowedIPs(t *testing.T) {
|
||||
v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
v4v6Addr.IPv6 = netip.MustParseAddr("fd00::1")
|
||||
v4v6Addr.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked()
|
||||
|
||||
v4OnlyAddr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
addr wgaddr.Address
|
||||
input []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "interface has v6, keep all",
|
||||
addr: v4v6Addr,
|
||||
input: []string{"100.64.0.1/32", "fd00::1/128"},
|
||||
expected: []string{"100.64.0.1/32", "fd00::1/128"},
|
||||
},
|
||||
{
|
||||
name: "no v6, strip v6",
|
||||
addr: v4OnlyAddr,
|
||||
input: []string{"100.64.0.1/32", "fd00::1/128"},
|
||||
expected: []string{"100.64.0.1/32"},
|
||||
},
|
||||
{
|
||||
name: "no v6, only v4",
|
||||
addr: v4OnlyAddr,
|
||||
input: []string{"100.64.0.1/32", "10.0.0.0/8"},
|
||||
expected: []string{"100.64.0.1/32", "10.0.0.0/8"},
|
||||
},
|
||||
{
|
||||
name: "no v6, only v6 input",
|
||||
addr: v4OnlyAddr,
|
||||
input: []string{"fd00::1/128", "::/0"},
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "no v6, invalid prefix preserved",
|
||||
addr: v4OnlyAddr,
|
||||
input: []string{"100.64.0.1/32", "garbage"},
|
||||
expected: []string{"100.64.0.1/32", "garbage"},
|
||||
},
|
||||
{
|
||||
name: "no v6, empty input",
|
||||
addr: v4OnlyAddr,
|
||||
input: []string{},
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := tt.addr
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{},
|
||||
wgInterface: &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return addr },
|
||||
},
|
||||
}
|
||||
result := engine.filterAllowedIPs(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOverlayAddrsFromAllowedIPs(t *testing.T) {
|
||||
ourV6Net := netip.MustParsePrefix("fd00:1234:5678:abcd::/64")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedIPs []string
|
||||
ourV6Net netip.Prefix
|
||||
wantV4 string
|
||||
wantV6 string
|
||||
}{
|
||||
{
|
||||
name: "v4 only",
|
||||
allowedIPs: []string{"100.64.0.1/32"},
|
||||
ourV6Net: ourV6Net,
|
||||
wantV4: "100.64.0.1",
|
||||
wantV6: "",
|
||||
},
|
||||
{
|
||||
name: "v4 and v6 overlay",
|
||||
allowedIPs: []string{"100.64.0.1/32", "fd00:1234:5678:abcd::1/128"},
|
||||
ourV6Net: ourV6Net,
|
||||
wantV4: "100.64.0.1",
|
||||
wantV6: "fd00:1234:5678:abcd::1",
|
||||
},
|
||||
{
|
||||
name: "v4, routed v6, overlay v6",
|
||||
allowedIPs: []string{"100.64.0.1/32", "2001:db8::1/128", "fd00:1234:5678:abcd::1/128"},
|
||||
ourV6Net: ourV6Net,
|
||||
wantV4: "100.64.0.1",
|
||||
wantV6: "fd00:1234:5678:abcd::1",
|
||||
},
|
||||
{
|
||||
name: "routed v6 /128 outside our subnet is ignored",
|
||||
allowedIPs: []string{"100.64.0.1/32", "2001:db8::1/128"},
|
||||
ourV6Net: ourV6Net,
|
||||
wantV4: "100.64.0.1",
|
||||
wantV6: "",
|
||||
},
|
||||
{
|
||||
name: "routed v6 prefix is ignored",
|
||||
allowedIPs: []string{"100.64.0.1/32", "fd00:1234:5678:abcd::/64"},
|
||||
ourV6Net: ourV6Net,
|
||||
wantV4: "100.64.0.1",
|
||||
wantV6: "",
|
||||
},
|
||||
{
|
||||
name: "no v6 subnet configured",
|
||||
allowedIPs: []string{"100.64.0.1/32", "fd00:1234:5678:abcd::1/128"},
|
||||
ourV6Net: netip.Prefix{},
|
||||
wantV4: "100.64.0.1",
|
||||
wantV6: "",
|
||||
},
|
||||
{
|
||||
name: "v4 /24 route is ignored",
|
||||
allowedIPs: []string{"100.64.0.0/24", "100.64.0.1/32"},
|
||||
ourV6Net: ourV6Net,
|
||||
wantV4: "100.64.0.1",
|
||||
wantV6: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v4, v6 := overlayAddrsFromAllowedIPs(tt.allowedIPs, tt.ourV6Net)
|
||||
if tt.wantV4 == "" {
|
||||
assert.False(t, v4.IsValid(), "expected no v4")
|
||||
} else {
|
||||
assert.Equal(t, tt.wantV4, v4.String(), "v4")
|
||||
}
|
||||
if tt.wantV6 == "" {
|
||||
assert.False(t, v6.IsValid(), "expected no v6")
|
||||
} else {
|
||||
assert.Equal(t, tt.wantV6, v6.String(), "v6")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ type wgIfaceBase interface {
|
||||
Address() wgaddr.Address
|
||||
ToInterface() *net.Interface
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
UpdateAddr(newAddr wgaddr.Address) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
|
||||
@@ -57,6 +57,7 @@ func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyc
|
||||
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
|
||||
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
|
||||
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
|
||||
// For IPv6-only peers, the last two bytes of the v6 address are used.
|
||||
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
|
||||
if len(allowedIPs) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
|
||||
@@ -64,6 +65,7 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e
|
||||
|
||||
ourNetwork := wgIface.Address().Network
|
||||
|
||||
// Try v4 first (preferred: deterministic from overlay IP)
|
||||
var peerIP netip.Addr
|
||||
for _, allowedIP := range allowedIPs {
|
||||
ip := allowedIP.Addr()
|
||||
@@ -76,13 +78,24 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e
|
||||
}
|
||||
}
|
||||
|
||||
if !peerIP.IsValid() {
|
||||
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
|
||||
if peerIP.IsValid() {
|
||||
octets := peerIP.As4()
|
||||
return netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}), nil
|
||||
}
|
||||
|
||||
octets := peerIP.As4()
|
||||
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]})
|
||||
return fakeIP, nil
|
||||
// Fallback: use last two bytes of first v6 overlay IP
|
||||
addr := wgIface.Address()
|
||||
if addr.IPv6Net.IsValid() {
|
||||
for _, allowedIP := range allowedIPs {
|
||||
ip := allowedIP.Addr()
|
||||
if ip.Is6() && addr.IPv6Net.Contains(ip) {
|
||||
raw := ip.As16()
|
||||
return netip.AddrFrom4([4]byte{127, 2, raw[14], raw[15]}), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
|
||||
}
|
||||
|
||||
func (d *BindListener) setupLazyConn() error {
|
||||
|
||||
@@ -5,4 +5,5 @@ type NetworkChangeListener interface {
|
||||
// OnNetworkChanged invoke when network settings has been changed
|
||||
OnNetworkChanged(string)
|
||||
SetInterfaceIP(string)
|
||||
SetInterfaceIPv6(string)
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||
case nftypes.TCP, nftypes.UDP, nftypes.SCTP:
|
||||
srcPort = flow.TupleOrig.Proto.SourcePort
|
||||
dstPort = flow.TupleOrig.Proto.DestinationPort
|
||||
case nftypes.ICMP:
|
||||
case nftypes.ICMP, nftypes.ICMPv6:
|
||||
icmpType = flow.TupleOrig.Proto.ICMPType
|
||||
icmpCode = flow.TupleOrig.Proto.ICMPCode
|
||||
}
|
||||
@@ -231,8 +231,14 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
|
||||
}
|
||||
|
||||
// fallback if mark rules are not in place
|
||||
wgnet := c.iface.Address().Network
|
||||
return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
|
||||
addr := c.iface.Address()
|
||||
if addr.Network.Contains(srcIP) || addr.Network.Contains(dstIP) {
|
||||
return true
|
||||
}
|
||||
if addr.IPv6Net.IsValid() {
|
||||
return addr.IPv6Net.Contains(srcIP) || addr.IPv6Net.Contains(dstIP)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// mapRxPackets maps packet counts to RX based on flow direction
|
||||
@@ -291,17 +297,16 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
|
||||
}
|
||||
|
||||
// fallback if marks are not set
|
||||
wgaddr := c.iface.Address().IP
|
||||
wgnetwork := c.iface.Address().Network
|
||||
addr := c.iface.Address()
|
||||
switch {
|
||||
case wgaddr == srcIP:
|
||||
case addr.IP == srcIP || (addr.IPv6.IsValid() && addr.IPv6 == srcIP):
|
||||
return nftypes.Egress
|
||||
case wgaddr == dstIP:
|
||||
case addr.IP == dstIP || (addr.IPv6.IsValid() && addr.IPv6 == dstIP):
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(srcIP):
|
||||
case addr.Network.Contains(srcIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(srcIP)):
|
||||
// netbird network -> resource network
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(dstIP):
|
||||
case addr.Network.Contains(dstIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(dstIP)):
|
||||
// resource network -> netbird network
|
||||
return nftypes.Egress
|
||||
}
|
||||
|
||||
@@ -24,15 +24,17 @@ type Logger struct {
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgIfaceNet netip.Prefix
|
||||
wgIfaceNetV6 netip.Prefix
|
||||
dnsCollection atomic.Bool
|
||||
exitNodeCollection atomic.Bool
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix) *Logger {
|
||||
return &Logger{
|
||||
statusRecorder: statusRecorder,
|
||||
wgIfaceNet: wgIfaceIPNet,
|
||||
wgIfaceNetV6: wgIfaceIPNetV6,
|
||||
Store: store.NewMemoryStore(),
|
||||
}
|
||||
}
|
||||
@@ -88,11 +90,11 @@ func (l *Logger) startReceiver() {
|
||||
var isSrcExitNode bool
|
||||
var isDestExitNode bool
|
||||
|
||||
if !l.wgIfaceNet.Contains(event.SourceIP) {
|
||||
if !l.isOverlayIP(event.SourceIP) {
|
||||
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
||||
}
|
||||
|
||||
if !l.wgIfaceNet.Contains(event.DestIP) {
|
||||
if !l.isOverlayIP(event.DestIP) {
|
||||
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
||||
}
|
||||
|
||||
@@ -136,6 +138,10 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
||||
l.exitNodeCollection.Store(exitNodeCollection)
|
||||
}
|
||||
|
||||
func (l *Logger) isOverlayIP(ip netip.Addr) bool {
|
||||
return l.wgIfaceNet.Contains(ip) || (l.wgIfaceNetV6.IsValid() && l.wgIfaceNetV6.Contains(ip))
|
||||
}
|
||||
|
||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||
// check dns collection
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP &&
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(nil, netip.Prefix{})
|
||||
logger := logger.New(nil, netip.Prefix{}, netip.Prefix{})
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
|
||||
@@ -35,11 +35,12 @@ type Manager struct {
|
||||
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
var prefix netip.Prefix
|
||||
var prefix, prefixV6 netip.Prefix
|
||||
if iface != nil {
|
||||
prefix = iface.Address().Network
|
||||
prefixV6 = iface.Address().IPv6Net
|
||||
}
|
||||
flowLogger := logger.New(statusRecorder, prefix)
|
||||
flowLogger := logger.New(statusRecorder, prefix, prefixV6)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
@@ -269,7 +270,7 @@ func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
|
||||
},
|
||||
}
|
||||
|
||||
if event.Protocol == nftypes.ICMP {
|
||||
if event.Protocol == nftypes.ICMP || event.Protocol == nftypes.ICMPv6 {
|
||||
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
|
||||
IcmpInfo: &proto.ICMPInfo{
|
||||
IcmpType: uint32(event.ICMPType),
|
||||
|
||||
@@ -19,6 +19,7 @@ const (
|
||||
ICMP = Protocol(1)
|
||||
TCP = Protocol(6)
|
||||
UDP = Protocol(17)
|
||||
ICMPv6 = Protocol(58)
|
||||
SCTP = Protocol(132)
|
||||
)
|
||||
|
||||
@@ -30,6 +31,8 @@ func (p Protocol) String() string {
|
||||
return "TCP"
|
||||
case 17:
|
||||
return "UDP"
|
||||
case 58:
|
||||
return "ICMPv6"
|
||||
case 132:
|
||||
return "SCTP"
|
||||
default:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user