mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-02 15:43:47 -04:00
Compare commits
78 Commits
feature/ex
...
refactor/n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
575b176371 | ||
|
|
b96dfb834e | ||
|
|
a654c286bf | ||
|
|
20c65cffc2 | ||
|
|
95cf5ab673 | ||
|
|
ed0c5a470a | ||
|
|
3a11175d34 | ||
|
|
d82176e9ea | ||
|
|
56ebd004d9 | ||
|
|
1ee575befe | ||
|
|
d3a34adcc9 | ||
|
|
bbb0910a2c | ||
|
|
ec035e5709 | ||
|
|
d7321c130b | ||
|
|
404cab90ba | ||
|
|
4545ab9a52 | ||
|
|
7f08983207 | ||
|
|
eddea14521 | ||
|
|
b9ef214ea5 | ||
|
|
709e24eb6f | ||
|
|
6654e2dbf7 | ||
|
|
d80d47a469 | ||
|
|
96f71ff1e1 | ||
|
|
2fe2af38d2 | ||
|
|
d0f8868e4d | ||
|
|
31e3928e08 | ||
|
|
1793ad8ff0 | ||
|
|
7e51803a58 | ||
|
|
4cab69a76c | ||
|
|
dace4ff30c | ||
|
|
069586b77f | ||
|
|
dac6fb16ba | ||
|
|
cd7dce7678 | ||
|
|
9bdabfbb2c | ||
|
|
2760c78bc2 | ||
|
|
2882638c7f | ||
|
|
9c7c15ba8b | ||
|
|
b0dcc9ff7b | ||
|
|
11e2573747 | ||
|
|
6ac9e58911 | ||
|
|
c115434d53 | ||
|
|
3984f0ee08 | ||
|
|
2dc5e7eb7a | ||
|
|
75dc7f6517 | ||
|
|
409e2013cc | ||
|
|
f13c6ba1b8 | ||
|
|
fc2f23bd22 | ||
|
|
26f5aee4b9 | ||
|
|
8d47175cdb | ||
|
|
274711a37e | ||
|
|
53e24ae7f7 | ||
|
|
fbc02343e9 | ||
|
|
ffed4b38ef | ||
|
|
e926ca34b5 | ||
|
|
5d1c61369d | ||
|
|
fd9e21a5f3 | ||
|
|
841bc7564a | ||
|
|
f20a1b3328 | ||
|
|
2ac0da6cac | ||
|
|
148b8b04b3 | ||
|
|
9a56883ffb | ||
|
|
806be13dd5 | ||
|
|
90557da237 | ||
|
|
1d6209841e | ||
|
|
8f0e5708d5 | ||
|
|
5a9aa55121 | ||
|
|
6082c7cdcb | ||
|
|
06eae13352 | ||
|
|
08fba9876b | ||
|
|
ca85aa9b8f | ||
|
|
0ae2241573 | ||
|
|
050c05164a | ||
|
|
333908d06e | ||
|
|
bc6c5ece6e | ||
|
|
fd7b3ae21c | ||
|
|
abd7a84a46 | ||
|
|
f4b2bed1b9 | ||
|
|
2fb971e88a |
@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
|
||||
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
statusResp, err := getStatus(cmd.Context(), true)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
|
||||
@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
resp, err := getStatus(ctx)
|
||||
resp, err := getStatus(ctx, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
@@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
ruleInfo := ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = ruleInfo.rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
|
||||
@@ -151,14 +151,20 @@ type Manager interface {
|
||||
|
||||
DisableRouting() error
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
|
||||
AddDNATRule(ForwardRule) (Rule, error)
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
// DeleteDNATRule deletes the outbound DNAT rule.
|
||||
DeleteDNATRule(Rule) error
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 3,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
|
||||
@@ -22,6 +22,8 @@ type BaseConnTrack struct {
|
||||
PacketsRx atomic.Uint64
|
||||
BytesTx atomic.Uint64
|
||||
BytesRx atomic.Uint64
|
||||
|
||||
DNATOrigPort atomic.Uint32
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
|
||||
@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
|
||||
if exists {
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, true
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
|
||||
return key, false
|
||||
return key, 0, false
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
|
||||
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
|
||||
return origPort
|
||||
}
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
|
||||
return 0
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
|
||||
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
|
||||
conn.tombstone.Store(false)
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *TCPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
|
||||
@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
serverPort := uint16(80)
|
||||
|
||||
// 1. Client sends SYN (we receive it as inbound)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: clientIP,
|
||||
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
|
||||
// 3. Client sends ACK to complete handshake
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||
|
||||
// 4. Test data transfer
|
||||
// Client sends data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
|
||||
|
||||
// Server sends ACK for data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||
|
||||
// Client sends ACK for data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
|
||||
// Verify state and counters
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||
// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
|
||||
_, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
|
||||
if exists {
|
||||
return origPort
|
||||
}
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
|
||||
return 0
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound UDP connection
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
|
||||
}
|
||||
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
return key, true
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
|
||||
return key, false
|
||||
return key, 0, false
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
|
||||
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
|
||||
@@ -109,6 +109,10 @@ type Manager struct {
|
||||
dnatMappings map[netip.Addr]netip.Addr
|
||||
dnatMutex sync.RWMutex
|
||||
dnatBiMap *biDNATMap
|
||||
|
||||
portDNATEnabled atomic.Bool
|
||||
portDNATRules []portDNATRule
|
||||
portDNATMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -122,6 +126,8 @@ type decoder struct {
|
||||
icmp6 layers.ICMPv6
|
||||
decoded []gopacket.LayerType
|
||||
parser *gopacket.DecodingLayerParser
|
||||
|
||||
dnatOrigPort uint16
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
@@ -196,6 +202,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
@@ -630,7 +637,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
m.trackOutbound(d, srcIP, dstIP, packetData, size)
|
||||
m.translateOutboundDNAT(packetData, d)
|
||||
|
||||
return false
|
||||
@@ -674,14 +681,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
if origPort == 0 {
|
||||
break
|
||||
}
|
||||
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite UDP port: %v", err)
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
if origPort == 0 {
|
||||
break
|
||||
}
|
||||
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite TCP port: %v", err)
|
||||
}
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||
}
|
||||
@@ -691,13 +710,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||
}
|
||||
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
@@ -759,10 +780,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: optimize port DNAT by caching matched rules in conntrack
|
||||
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
||||
// Re-decode after port DNAT translation to update port information
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
|
||||
@@ -50,6 +50,8 @@ type logMessage struct {
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
@@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) {
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
|
||||
func (l *Logger) Error(format string) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
@@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
@@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
}
|
||||
}
|
||||
|
||||
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
|
||||
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
*buf = (*buf)[:0]
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
@@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
argCount++
|
||||
if msg.arg6 != nil {
|
||||
argCount++
|
||||
if msg.arg7 != nil {
|
||||
argCount++
|
||||
if msg.arg8 != nil {
|
||||
argCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
|
||||
case 6:
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
|
||||
case 7:
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
|
||||
case 8:
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
|
||||
}
|
||||
|
||||
*buf = append(*buf, formatted...)
|
||||
@@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error {
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -13,6 +15,21 @@ import (
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
var (
|
||||
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
||||
)
|
||||
|
||||
const (
|
||||
// Port offsets in TCP/UDP headers
|
||||
sourcePortOffset = 0
|
||||
destinationPortOffset = 2
|
||||
|
||||
// IP address offsets in IPv4 header
|
||||
sourceIPOffset = 12
|
||||
destinationIPOffset = 16
|
||||
)
|
||||
|
||||
// ipv4Checksum calculates IPv4 header checksum.
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
return 0
|
||||
@@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// icmpChecksum calculates ICMP checksum.
|
||||
func icmpChecksum(data []byte) uint16 {
|
||||
var sum1, sum2, sum3, sum4 uint32
|
||||
i := 0
|
||||
@@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// biDNATMap maintains bidirectional DNAT mappings.
|
||||
type biDNATMap struct {
|
||||
forward map[netip.Addr]netip.Addr
|
||||
reverse map[netip.Addr]netip.Addr
|
||||
}
|
||||
|
||||
// portDNATRule represents a port-specific DNAT rule.
|
||||
type portDNATRule struct {
|
||||
protocol gopacket.LayerType
|
||||
origPort uint16
|
||||
targetPort uint16
|
||||
targetIP netip.Addr
|
||||
}
|
||||
|
||||
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
|
||||
func newBiDNATMap() *biDNATMap {
|
||||
return &biDNATMap{
|
||||
forward: make(map[netip.Addr]netip.Addr),
|
||||
@@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap {
|
||||
}
|
||||
}
|
||||
|
||||
// set adds a bidirectional DNAT mapping between original and translated addresses.
|
||||
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||
b.forward[original] = translated
|
||||
b.reverse[translated] = original
|
||||
}
|
||||
|
||||
// delete removes a bidirectional DNAT mapping for the given original address.
|
||||
func (b *biDNATMap) delete(original netip.Addr) {
|
||||
if translated, exists := b.forward[original]; exists {
|
||||
delete(b.forward, original)
|
||||
@@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
|
||||
}
|
||||
}
|
||||
|
||||
// getTranslated returns the translated address for a given original address.
|
||||
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||
translated, exists := b.forward[original]
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// getOriginal returns the original address for a given translated address.
|
||||
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||
original, exists := b.reverse[translated]
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
|
||||
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid IP addresses")
|
||||
if !originalAddr.IsValid() {
|
||||
return fmt.Errorf("invalid original IP address")
|
||||
}
|
||||
if !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid translated IP address")
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||
@@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
// Initialize both maps together if either is nil
|
||||
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
m.dnatBiMap = newBiDNATMap()
|
||||
@@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping.
|
||||
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
@@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNATTranslation returns the translated address if a mapping exists
|
||||
// getDNATTranslation returns the translated address if a mapping exists.
|
||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return addr, false
|
||||
@@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// findReverseDNATMapping finds original address for return traffic
|
||||
// findReverseDNATMapping finds original address for return traffic.
|
||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return translatedAddr, false
|
||||
@@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets.
|
||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
@@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||
m.logger.Error1("Failed to rewrite packet destination: %v", err)
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic.
|
||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
@@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||
m.logger.Error1("Failed to rewrite packet source: %v", err)
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketDestination replaces destination IP in the packet
|
||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
|
||||
if !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldDst [4]byte
|
||||
copy(oldDst[:], packetData[16:20])
|
||||
newDst := newIP.As4()
|
||||
var oldIP [4]byte
|
||||
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
|
||||
newIPBytes := newIP.As4()
|
||||
|
||||
copy(packetData[16:20], newDst[:])
|
||||
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
@@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewritePacketSource replaces the source IP address in the packet
|
||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldSrc [4]byte
|
||||
copy(oldSrc[:], packetData[12:16])
|
||||
newSrc := newIP.As4()
|
||||
|
||||
copy(packetData[12:16], newSrc[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
@@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
|
||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+18 {
|
||||
@@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
|
||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
@@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateICMPChecksum recalculates ICMP checksum after packet modification.
|
||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+8 {
|
||||
@@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
|
||||
@@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
@@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||
// DeleteDNATRule deletes outbound DNAT rule.
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// addPortRedirection adds a port redirection rule.
|
||||
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
rule := portDNATRule{
|
||||
protocol: protocol,
|
||||
origPort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
targetIP: targetIP,
|
||||
}
|
||||
|
||||
m.portDNATRules = append(m.portDNATRules, rule)
|
||||
m.portDNATEnabled.Store(true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
layerType = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
layerType = layers.LayerTypeUDP
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// removePortRedirection removes a port redirection rule.
|
||||
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
|
||||
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
|
||||
})
|
||||
|
||||
if len(m.portDNATRules) == 0 {
|
||||
m.portDNATEnabled.Store(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
layerType = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
layerType = layers.LayerTypeUDP
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
|
||||
case layers.LayerTypeUDP:
|
||||
dstPort := uint16(d.udp.DstPort)
|
||||
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
|
||||
|
||||
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
|
||||
m.portDNATMutex.RLock()
|
||||
defer m.portDNATMutex.RUnlock()
|
||||
|
||||
for _, rule := range m.portDNATRules {
|
||||
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.origPort != port {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
return false
|
||||
}
|
||||
d.dnatOrigPort = rule.origPort
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+4 {
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
}
|
||||
|
||||
portStart := tcpStart + portOffset
|
||||
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
|
||||
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
|
||||
|
||||
if len(packetData) >= tcpStart+18 {
|
||||
checksumOffset := tcpStart + 16
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return fmt.Errorf("packet too short for UDP header")
|
||||
}
|
||||
|
||||
portStart := udpStart + portOffset
|
||||
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
|
||||
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
|
||||
|
||||
checksumOffset := udpStart + 6
|
||||
if len(packetData) >= udpStart+8 {
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
if oldChecksum != 0 {
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkPortDNAT measures the performance of port DNAT operations
|
||||
func BenchmarkPortDNAT(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
proto layers.IPProtocol
|
||||
setupDNAT bool
|
||||
useMatchPort bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "tcp_inbound_dnat_match",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: true,
|
||||
description: "TCP inbound port DNAT translation (22 → 22022)",
|
||||
},
|
||||
{
|
||||
name: "tcp_inbound_dnat_nomatch",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: false,
|
||||
description: "TCP inbound with DNAT configured but no port match",
|
||||
},
|
||||
{
|
||||
name: "tcp_inbound_no_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: false,
|
||||
useMatchPort: false,
|
||||
description: "TCP inbound without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_dnat_match",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: true,
|
||||
description: "UDP inbound port DNAT translation (5353 → 22054)",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_dnat_nomatch",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: false,
|
||||
description: "UDP inbound with DNAT configured but no port match",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_no_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: false,
|
||||
useMatchPort: false,
|
||||
description: "UDP inbound without DNAT (baseline)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
var origPort, targetPort, testPort uint16
|
||||
if sc.proto == layers.IPProtocolTCP {
|
||||
origPort, targetPort = 22, 22022
|
||||
} else {
|
||||
origPort, targetPort = 5353, 22054
|
||||
}
|
||||
|
||||
if sc.useMatchPort {
|
||||
testPort = origPort
|
||||
} else {
|
||||
testPort = 443 // Different port
|
||||
}
|
||||
|
||||
// Setup port DNAT mapping if needed
|
||||
if sc.setupDNAT {
|
||||
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Pre-establish inbound connection for outbound reverse test
|
||||
if sc.setupDNAT && sc.useMatchPort {
|
||||
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
|
||||
manager.filterInbound(inboundPacket, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
// Benchmark inbound DNAT translation
|
||||
b.Run("inbound", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time
|
||||
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
|
||||
manager.filterInbound(packet, 0)
|
||||
}
|
||||
})
|
||||
|
||||
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
|
||||
if sc.setupDNAT && sc.useMatchPort {
|
||||
b.Run("outbound_reverse", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh return packet (from target port)
|
||||
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -143,3 +144,111 @@ func TestDNATMappingManagement(t *testing.T) {
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||
}
|
||||
|
||||
func TestInboundPortDNAT(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
sourcePort uint16
|
||||
targetPort uint16
|
||||
}{
|
||||
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
|
||||
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
|
||||
require.NoError(t, err)
|
||||
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
|
||||
d := parsePacket(t, inboundPacket)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
|
||||
require.True(t, translated, "Inbound packet should be translated")
|
||||
|
||||
d = parsePacket(t, inboundPacket)
|
||||
var dstPort uint16
|
||||
switch tc.protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
dstPort = uint16(d.tcp.DstPort)
|
||||
case layers.IPProtocolUDP:
|
||||
dstPort = uint16(d.udp.DstPort)
|
||||
}
|
||||
|
||||
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
|
||||
|
||||
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundPortDNATNegative(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
|
||||
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
|
||||
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
|
||||
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
|
||||
d := parsePacket(t, packet)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
|
||||
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
|
||||
|
||||
d = parsePacket(t, packet)
|
||||
if tc.protocol == layers.IPProtocolTCP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
|
||||
} else if tc.protocol == layers.IPProtocolUDP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
|
||||
switch proto {
|
||||
case layers.IPProtocolTCP:
|
||||
return firewall.ProtocolTCP
|
||||
case layers.IPProtocolUDP:
|
||||
return firewall.ProtocolUDP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,25 +16,33 @@ type PacketStage int
|
||||
|
||||
const (
|
||||
StageReceived PacketStage = iota
|
||||
StageInboundPortDNAT
|
||||
StageInbound1to1NAT
|
||||
StageConntrack
|
||||
StagePeerACL
|
||||
StageRouting
|
||||
StageRouteACL
|
||||
StageForwarding
|
||||
StageCompleted
|
||||
StageOutbound1to1NAT
|
||||
StageOutboundPortReverse
|
||||
)
|
||||
|
||||
const msgProcessingCompleted = "Processing completed"
|
||||
|
||||
func (s PacketStage) String() string {
|
||||
return map[PacketStage]string{
|
||||
StageReceived: "Received",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
StageReceived: "Received",
|
||||
StageInboundPortDNAT: "Inbound Port DNAT",
|
||||
StageInbound1to1NAT: "Inbound 1:1 NAT",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
StageOutbound1to1NAT: "Outbound 1:1 NAT",
|
||||
StageOutboundPortReverse: "Outbound DNAT Reverse",
|
||||
}[s]
|
||||
}
|
||||
|
||||
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
||||
return trace
|
||||
}
|
||||
|
||||
m.handleOutboundDNAT(trace, packetData, d)
|
||||
|
||||
dropped := m.filterOutbound(packetData, 0)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
}
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
||||
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
||||
if portDNATApplied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
||||
return true
|
||||
}
|
||||
*srcIP, *dstIP = m.extractIPs(d)
|
||||
trace.DestinationPort = m.getDestPort(d)
|
||||
}
|
||||
|
||||
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
||||
if nat1to1Applied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
||||
return true
|
||||
}
|
||||
*srcIP, *dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
protocol := d.decoded[1]
|
||||
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
|
||||
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
var originalPort uint16
|
||||
if protocol == layers.LayerTypeTCP {
|
||||
originalPort = uint16(d.tcp.DstPort)
|
||||
} else {
|
||||
originalPort = uint16(d.udp.DstPort)
|
||||
}
|
||||
|
||||
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
|
||||
if translated {
|
||||
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
|
||||
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
|
||||
|
||||
protoStr := "TCP"
|
||||
if protocol == layers.LayerTypeUDP {
|
||||
protoStr = "UDP"
|
||||
}
|
||||
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
|
||||
trace.AddResult(StageInboundPortDNAT, msg, true)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
translated := m.translateInboundReverse(packetData, d)
|
||||
if translated {
|
||||
m.dnatMutex.RLock()
|
||||
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
|
||||
m.dnatMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
|
||||
trace.AddResult(StageInbound1to1NAT, msg, true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
|
||||
m.traceOutbound1to1NAT(trace, packetData, d)
|
||||
m.traceOutboundPortReverse(trace, packetData, d)
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translated := m.translateOutboundDNAT(packetData, d)
|
||||
if translated {
|
||||
m.dnatMutex.RLock()
|
||||
translatedIP, exists := m.dnatMappings[dstIP]
|
||||
m.dnatMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
|
||||
trace.AddResult(StageOutbound1to1NAT, msg, true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
var origPort uint16
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeTCP:
|
||||
srcPort := uint16(d.tcp.SrcPort)
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
|
||||
if exists {
|
||||
origPort = uint16(conn.DNATOrigPort.Load())
|
||||
}
|
||||
if origPort != 0 {
|
||||
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
|
||||
trace.AddResult(StageOutboundPortReverse, msg, true)
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
srcPort := uint16(d.udp.SrcPort)
|
||||
dstPort := uint16(d.udp.DstPort)
|
||||
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
|
||||
if exists {
|
||||
origPort = uint16(conn.DNATOrigPort.Load())
|
||||
}
|
||||
if origPort != 0 {
|
||||
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
|
||||
trace.AddResult(StageOutboundPortReverse, msg, true)
|
||||
return true
|
||||
}
|
||||
default:
|
||||
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) getDestPort(d *decoder) uint16 {
|
||||
if len(d.decoded) < 2 {
|
||||
return 0
|
||||
}
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
return uint16(d.udp.DstPort)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageCompleted,
|
||||
@@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageCompleted,
|
||||
},
|
||||
@@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
@@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
|
||||
@@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states.
|
||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||
mutex.prof: Mutex profiling information.
|
||||
goroutine.prof: Goroutine profiling information.
|
||||
block.prof: Block profiling information.
|
||||
@@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Adding state file from: %s", path)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
var err error
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
err = s.recordSystemDNSSettings(true)
|
||||
if err != nil {
|
||||
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
||||
}
|
||||
|
||||
if config.RouteAll {
|
||||
searchDomains = append(searchDomains, "\"\"")
|
||||
err = s.addLocalDNS()
|
||||
if err != nil {
|
||||
log.Infof("failed to enable split DNS")
|
||||
if err := s.addLocalDNS(); err != nil {
|
||||
log.Warnf("failed to add local DNS: %v", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
}
|
||||
|
||||
for _, dConf := range config.Domains {
|
||||
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
var err error
|
||||
if len(matchDomains) != 0 {
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
if err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
if len(searchDomains) != 0 {
|
||||
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
if err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) string() string {
|
||||
return "scutil"
|
||||
}
|
||||
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
func (s *systemConfigurator) addLocalDNS() error {
|
||||
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||
log.Errorf("Unable to get system DNS configuration")
|
||||
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
|
||||
}
|
||||
}
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
|
||||
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
|
||||
}
|
||||
} else {
|
||||
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||
log.Info("Not enabling local DNS server")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.addSearchDomains(
|
||||
localKey,
|
||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||
); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
111
client/internal/dns/host_darwin_test.go
Normal file
111
client/internal/dns/host_darwin_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping scutil integration test in short mode")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
stateFile := filepath.Join(tmpDir, "state.json")
|
||||
|
||||
sm := statemanager.New(stateFile)
|
||||
sm.RegisterState(&ShutdownState{})
|
||||
sm.Start()
|
||||
defer func() {
|
||||
require.NoError(t, sm.Stop(context.Background()))
|
||||
}()
|
||||
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netip.MustParseAddr("100.64.0.1"),
|
||||
ServerPort: 53,
|
||||
RouteAll: true,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "example.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err := configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, sm.PersistState(context.Background()))
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
defer func() {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
if exists {
|
||||
t.Logf("Key %s exists before cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
sm2 := statemanager.New(stateFile)
|
||||
sm2.RegisterState(&ShutdownState{})
|
||||
err = sm2.LoadState(&ShutdownState{})
|
||||
require.NoError(t, err)
|
||||
|
||||
state := sm2.GetState(&ShutdownState{})
|
||||
if state == nil {
|
||||
t.Skip("State not saved, skipping cleanup test")
|
||||
}
|
||||
|
||||
shutdownState, ok := state.(*ShutdownState)
|
||||
require.True(t, ok)
|
||||
|
||||
err = shutdownState.Cleanup()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNSKeyExists(key string) (bool, error) {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
if strings.Contains(string(output), "No such key") {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return !strings.Contains(string(output), "No such key"), nil
|
||||
}
|
||||
|
||||
func removeTestDNSKey(key string) error {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
|
||||
_, err := cmd.CombinedOutput()
|
||||
return err
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
r.updateState(stateManager)
|
||||
|
||||
var searchDomains, matchDomains []string
|
||||
for _, dConf := range config.Domains {
|
||||
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
|
||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||
log.Errorf("cleanup old dns match policies: %s", err)
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
if err != nil {
|
||||
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
}
|
||||
r.nrptEntryCount = count
|
||||
} else {
|
||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||
return fmt.Errorf("remove dns match policies: %w", err)
|
||||
}
|
||||
r.nrptEntryCount = 0
|
||||
}
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
r.updateState(stateManager)
|
||||
|
||||
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
|
||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
|
||||
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||
}
|
||||
|
||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||
regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||
return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||
}
|
||||
defer closer(regKey)
|
||||
|
||||
|
||||
102
client/internal/dns/host_windows_test.go
Normal file
102
client/internal/dns/host_windows_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
// when the number of match domains decreases between configuration changes.
|
||||
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
}
|
||||
|
||||
defer cleanupRegistryKeys(t)
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
testIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||
require.NoError(t, err, "Should create test interface registry key")
|
||||
testKey.Close()
|
||||
defer func() {
|
||||
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
config5 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
{Domain: "domain3.com", MatchOnly: true},
|
||||
{Domain: "domain4.com", MatchOnly: true},
|
||||
{Domain: "domain5.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config5, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all 5 entries exist
|
||||
for i := 0; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Entry %d should exist after first config", i)
|
||||
}
|
||||
|
||||
config2 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify first 2 entries exist
|
||||
for i := 0; i < 2; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Entry %d should exist after second config", i)
|
||||
}
|
||||
|
||||
// Verify entries 2-4 are cleaned up
|
||||
for i := 2; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
||||
}
|
||||
}
|
||||
|
||||
func registryKeyExists(path string) (bool, error) {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
k.Close()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func cleanupRegistryKeys(*testing.T) {
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
||||
_ = cfg.removeDNSMatchPolicies()
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type ShutdownState struct {
|
||||
CreatedKeys []string
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
|
||||
return fmt.Errorf("create host manager: %w", err)
|
||||
}
|
||||
|
||||
for _, key := range s.CreatedKeys {
|
||||
manager.createdKeys[key] = struct{}{}
|
||||
}
|
||||
|
||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -12,18 +14,14 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
var (
|
||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||
listenPort uint16 = 5353
|
||||
listenPortMu sync.RWMutex
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTTL = 60 //seconds
|
||||
dnsTTL = 60
|
||||
envServerPort = "NB_DNS_FORWARDER_PORT"
|
||||
)
|
||||
|
||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||
@@ -36,28 +34,30 @@ type ForwarderEntry struct {
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
statusRecorder *peer.Status
|
||||
localAddr netip.Addr
|
||||
serverPort uint16
|
||||
|
||||
fwRules []firewall.Rule
|
||||
tcpRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func ListenPort() uint16 {
|
||||
listenPortMu.RLock()
|
||||
defer listenPortMu.RUnlock()
|
||||
return listenPort
|
||||
}
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager {
|
||||
serverPort := nbdns.ForwarderServerPort
|
||||
if envPort := os.Getenv(envServerPort); envPort != "" {
|
||||
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
|
||||
serverPort = uint16(port)
|
||||
log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort)
|
||||
} else {
|
||||
log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort)
|
||||
}
|
||||
}
|
||||
|
||||
func SetListenPort(port uint16) {
|
||||
listenPortMu.Lock()
|
||||
listenPort = port
|
||||
listenPortMu.Unlock()
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
localAddr: localAddr,
|
||||
serverPort: serverPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
}
|
||||
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
}
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
@@ -96,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
}
|
||||
|
||||
var mErr *multierror.Error
|
||||
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
|
||||
}
|
||||
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.dropDNSFirewall(); err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
}
|
||||
@@ -111,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
func (m *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
IsRange: false,
|
||||
Values: []uint16{ListenPort()},
|
||||
Values: []uint16{m.serverPort},
|
||||
}
|
||||
|
||||
if m.firewall == nil {
|
||||
|
||||
@@ -203,8 +203,7 @@ type Engine struct {
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
|
||||
// dns forwarder port
|
||||
dnsFwdPort uint16
|
||||
probeStunTurn *relay.StunTurnProbe
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -247,7 +246,7 @@ func NewEngine(
|
||||
statusRecorder: statusRecorder,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
dnsFwdPort: dnsfwd.ListenPort(),
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
}
|
||||
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
@@ -1060,10 +1059,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||
@@ -1084,7 +1087,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
|
||||
// Ingress forward rules
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
@@ -1208,10 +1211,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
||||
if forwarderPort == 0 {
|
||||
forwarderPort = nbdns.ForwarderClientPort
|
||||
}
|
||||
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
ForwarderPort: forwarderPort,
|
||||
}
|
||||
|
||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||
@@ -1667,7 +1676,7 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
func (e *Engine) RunHealthProbes() bool {
|
||||
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
signalHealthy := e.signal.IsHealthy()
|
||||
@@ -1699,8 +1708,12 @@ func (e *Engine) RunHealthProbes() bool {
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
results := e.probeICE(stuns, turns)
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
} else {
|
||||
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
relayHealthy := true
|
||||
@@ -1717,13 +1730,6 @@ func (e *Engine) RunHealthProbes() bool {
|
||||
return allHealthy
|
||||
}
|
||||
|
||||
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
||||
return append(
|
||||
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
|
||||
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
func (e *Engine) restartEngine() {
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -1843,16 +1849,11 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
||||
func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||
forwarderPort uint16,
|
||||
) {
|
||||
if e.config.DisableServerRoutes {
|
||||
return
|
||||
}
|
||||
|
||||
if forwarderPort > 0 {
|
||||
dnsfwd.SetListenPort(forwarderPort)
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
@@ -1864,20 +1865,17 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
|
||||
if len(fwdEntries) > 0 {
|
||||
switch {
|
||||
case e.dnsForwardMgr == nil:
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
if e.dnsForwardMgr == nil {
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
case e.dnsFwdPort != forwarderPort:
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
e.restartDnsFwd(fwdEntries, forwarderPort)
|
||||
e.dnsFwdPort = forwarderPort
|
||||
|
||||
default:
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
} else {
|
||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
@@ -1887,20 +1885,6 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
// stop and start the forwarder to apply the new port
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/store"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type rcvChan chan *types.EventFields
|
||||
@@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
||||
|
||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||
// check dns collection
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) {
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP &&
|
||||
(event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -15,6 +17,15 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCacheTTL = 20 * time.Second
|
||||
probeTimeout = 6 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCheckInProgress = errors.New("probe check is already in progress")
|
||||
)
|
||||
|
||||
// ProbeResult holds the info about the result of a relay probe request
|
||||
type ProbeResult struct {
|
||||
URI string
|
||||
@@ -22,8 +33,164 @@ type ProbeResult struct {
|
||||
Addr string
|
||||
}
|
||||
|
||||
type StunTurnProbe struct {
|
||||
cacheResults []ProbeResult
|
||||
cacheTimestamp time.Time
|
||||
cacheKey string
|
||||
cacheTTL time.Duration
|
||||
probeInProgress bool
|
||||
probeDone chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe {
|
||||
return &StunTurnProbe{
|
||||
cacheTTL: cacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
|
||||
cacheKey := generateCacheKey(stuns, turns)
|
||||
|
||||
p.mu.Lock()
|
||||
if p.probeInProgress {
|
||||
doneChan := p.probeDone
|
||||
p.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Context cancelled while waiting for probe results")
|
||||
return createErrorResults(stuns, turns)
|
||||
case <-doneChan:
|
||||
return p.getCachedResults(cacheKey, stuns, turns)
|
||||
}
|
||||
}
|
||||
|
||||
p.probeInProgress = true
|
||||
probeDone := make(chan struct{})
|
||||
p.probeDone = probeDone
|
||||
p.mu.Unlock()
|
||||
|
||||
p.doProbe(ctx, stuns, turns, cacheKey)
|
||||
close(probeDone)
|
||||
|
||||
return p.getCachedResults(cacheKey, stuns, turns)
|
||||
}
|
||||
|
||||
// ProbeAll probes all given servers asynchronously and returns the results
|
||||
func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
|
||||
cacheKey := generateCacheKey(stuns, turns)
|
||||
|
||||
p.mu.Lock()
|
||||
|
||||
if results := p.checkCache(cacheKey); results != nil {
|
||||
p.mu.Unlock()
|
||||
return results
|
||||
}
|
||||
|
||||
if p.probeInProgress {
|
||||
p.mu.Unlock()
|
||||
return createErrorResults(stuns, turns)
|
||||
}
|
||||
|
||||
p.probeInProgress = true
|
||||
probeDone := make(chan struct{})
|
||||
p.probeDone = probeDone
|
||||
log.Infof("started new probe for STUN, TURN servers")
|
||||
go func() {
|
||||
p.doProbe(ctx, stuns, turns, cacheKey)
|
||||
close(probeDone)
|
||||
}()
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
timer := time.NewTimer(1300 * time.Millisecond)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Context cancelled while waiting for probe results")
|
||||
return createErrorResults(stuns, turns)
|
||||
case <-probeDone:
|
||||
// when the probe is return fast, return the results right away
|
||||
return p.getCachedResults(cacheKey, stuns, turns)
|
||||
case <-timer.C:
|
||||
// if the probe takes longer than 1.3s, return error results to avoid blocking
|
||||
return createErrorResults(stuns, turns)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult {
|
||||
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
|
||||
age := time.Since(p.cacheTimestamp)
|
||||
if age < p.cacheTTL {
|
||||
results := append([]ProbeResult(nil), p.cacheResults...)
|
||||
log.Debugf("returning cached probe results (age: %v)", age)
|
||||
return results
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
|
||||
return append([]ProbeResult(nil), p.cacheResults...)
|
||||
}
|
||||
return createErrorResults(stuns, turns)
|
||||
}
|
||||
|
||||
func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) {
|
||||
defer func() {
|
||||
p.mu.Lock()
|
||||
p.probeInProgress = false
|
||||
p.mu.Unlock()
|
||||
}()
|
||||
results := make([]ProbeResult, len(stuns)+len(turns))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, uri := range stuns {
|
||||
wg.Add(1)
|
||||
go func(idx int, stunURI *stun.URI) {
|
||||
defer wg.Done()
|
||||
|
||||
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
|
||||
defer cancel()
|
||||
|
||||
results[idx].URI = stunURI.String()
|
||||
results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI)
|
||||
}(i, uri)
|
||||
}
|
||||
|
||||
stunOffset := len(stuns)
|
||||
for i, uri := range turns {
|
||||
wg.Add(1)
|
||||
go func(idx int, turnURI *stun.URI) {
|
||||
defer wg.Done()
|
||||
|
||||
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
|
||||
defer cancel()
|
||||
|
||||
results[idx].URI = turnURI.String()
|
||||
results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI)
|
||||
}(stunOffset+i, uri)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
p.mu.Lock()
|
||||
p.cacheResults = results
|
||||
p.cacheTimestamp = time.Now()
|
||||
p.cacheKey = cacheKey
|
||||
p.mu.Unlock()
|
||||
|
||||
log.Debug("Stored new probe results in cache")
|
||||
}
|
||||
|
||||
// ProbeSTUN tries binding to the given STUN uri and acquiring an address
|
||||
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||
func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||
defer func() {
|
||||
if probeErr != nil {
|
||||
log.Debugf("stun probe error from %s: %s", uri, probeErr)
|
||||
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
}
|
||||
|
||||
// ProbeTURN tries allocating a session from the given TURN URI
|
||||
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||
func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||
defer func() {
|
||||
if probeErr != nil {
|
||||
log.Debugf("turn probe error from %s: %s", uri, probeErr)
|
||||
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
return relayConn.LocalAddr().String(), nil
|
||||
}
|
||||
|
||||
// ProbeAll probes all given servers asynchronously and returns the results
|
||||
func ProbeAll(
|
||||
ctx context.Context,
|
||||
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
|
||||
relays []*stun.URI,
|
||||
) []ProbeResult {
|
||||
results := make([]ProbeResult, len(relays))
|
||||
func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
|
||||
total := len(stuns) + len(turns)
|
||||
results := make([]ProbeResult, total)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, uri := range relays {
|
||||
ctx, cancel := context.WithTimeout(ctx, 6*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
go func(res *ProbeResult, stunURI *stun.URI) {
|
||||
defer wg.Done()
|
||||
res.URI = stunURI.String()
|
||||
res.Addr, res.Err = fn(ctx, stunURI)
|
||||
}(&results[i], uri)
|
||||
allURIs := append(append([]*stun.URI{}, stuns...), turns...)
|
||||
for i, uri := range allURIs {
|
||||
results[i] = ProbeResult{
|
||||
URI: uri.String(),
|
||||
Err: ErrCheckInProgress,
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string {
|
||||
h := sha256.New()
|
||||
for _, uri := range stuns {
|
||||
h.Write([]byte(uri.String()))
|
||||
}
|
||||
for _, uri := range turns {
|
||||
h.Write([]byte(uri.String()))
|
||||
}
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -25,4 +26,5 @@ type HandlerParams struct {
|
||||
UseNewDNSRoute bool
|
||||
Firewall manager.Manager
|
||||
FakeIPManager *fakeip.Manager
|
||||
ForwarderPort *atomic.Uint32
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -18,7 +19,6 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
@@ -55,6 +55,7 @@ type DnsInterceptor struct {
|
||||
peerStore *peerstore.Store
|
||||
firewall firewall.Manager
|
||||
fakeIPManager *fakeip.Manager
|
||||
forwarderPort *atomic.Uint32
|
||||
}
|
||||
|
||||
func New(params common.HandlerParams) *DnsInterceptor {
|
||||
@@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor {
|
||||
firewall: params.Firewall,
|
||||
fakeIPManager: params.FakeIPManager,
|
||||
interceptedDomains: make(domainMap),
|
||||
forwarderPort: params.ForwarderPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort())
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||
@@ -54,6 +56,7 @@ type Manager interface {
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
SetFirewall(firewall.Manager) error
|
||||
SetDNSForwarderPort(port uint16)
|
||||
Stop(stateManager *statemanager.Manager)
|
||||
}
|
||||
|
||||
@@ -101,12 +104,13 @@ type DefaultManager struct {
|
||||
disableServerRoutes bool
|
||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||
fakeIPManager *fakeip.Manager
|
||||
dnsForwarderPort atomic.Uint32
|
||||
}
|
||||
|
||||
func NewManager(config ManagerConfig) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(config.Context)
|
||||
notifier := notifier.NewNotifier()
|
||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||
sysOps := systemops.New(config.WGInterface, notifier)
|
||||
|
||||
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||
@@ -130,6 +134,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
disableServerRoutes: config.DisableServerRoutes,
|
||||
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
|
||||
}
|
||||
dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort))
|
||||
|
||||
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
|
||||
dm.setupRefCounters(useNoop)
|
||||
@@ -270,6 +275,11 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDNSForwarderPort sets the DNS forwarder port for route handlers
|
||||
func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
|
||||
m.dnsForwarderPort.Store(uint32(port))
|
||||
}
|
||||
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
m.stop()
|
||||
@@ -345,6 +355,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
UseNewDNSRoute: m.useNewDNSRoute,
|
||||
Firewall: m.firewall,
|
||||
FakeIPManager: m.fakeIPManager,
|
||||
ForwarderPort: &m.dnsForwarderPort,
|
||||
}
|
||||
handler := client.HandlerFromRoute(params)
|
||||
if err := handler.AddRoute(m.ctx); err != nil {
|
||||
|
||||
@@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface
|
||||
func (m *MockManager) SetDNSForwarderPort(port uint16) {
|
||||
}
|
||||
|
||||
// Stop mock implementation of Stop from Manager interface
|
||||
func (m *MockManager) Stop(stateManager *statemanager.Manager) {
|
||||
if m.StopFunc != nil {
|
||||
|
||||
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
|
||||
|
||||
package systemops
|
||||
|
||||
// FlushMarkedRoutes is a no-op on non-BSD platforms.
|
||||
func (r *SysOps) FlushMarkedRoutes() error {
|
||||
return nil
|
||||
}
|
||||
@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
sysops := NewSysOps(nil, nil)
|
||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
||||
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
||||
sysOps := New(nil, nil)
|
||||
sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
|
||||
sysOps.refCounter.LoadData((*ExclusionCounter)(s))
|
||||
|
||||
return sysops.refCounter.Flush()
|
||||
return sysOps.refCounter.Flush()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||
|
||||
@@ -83,7 +83,7 @@ type SysOps struct {
|
||||
localSubnetsCacheTime time.Time
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
notifier: notifier,
|
||||
|
||||
@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := NewSysOps(nil, nil)
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := NewSysOps(nil, nil)
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
|
||||
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
r := New(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
require.NoError(t, err)
|
||||
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
r := New(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
require.NoError(t, err)
|
||||
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
|
||||
assert.NoError(t, wgInterface.Close())
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
r := New(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
|
||||
@@ -7,19 +7,39 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
||||
)
|
||||
|
||||
var routeProtoFlag int
|
||||
|
||||
func init() {
|
||||
switch os.Getenv(envRouteProtoFlag) {
|
||||
case "2":
|
||||
routeProtoFlag = unix.RTF_PROTO2
|
||||
case "3":
|
||||
routeProtoFlag = unix.RTF_PROTO3
|
||||
default:
|
||||
routeProtoFlag = unix.RTF_PROTO1
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
@@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
||||
func (r *SysOps) FlushMarkedRoutes() error {
|
||||
rib, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch routing table: %w", err)
|
||||
}
|
||||
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse routing table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
flushedCount := 0
|
||||
|
||||
for _, msg := range msgs {
|
||||
rtMsg, ok := msg.(*route.RouteMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
routeInfo, err := MsgToRoute(rtMsg)
|
||||
if err != nil {
|
||||
log.Debugf("Skipping route flush: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
|
||||
continue
|
||||
}
|
||||
|
||||
nexthop := Nexthop{
|
||||
IP: routeInfo.Gw,
|
||||
Intf: routeInfo.Interface,
|
||||
}
|
||||
|
||||
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
|
||||
continue
|
||||
}
|
||||
|
||||
flushedCount++
|
||||
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
|
||||
}
|
||||
|
||||
if flushedCount > 0 {
|
||||
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||
}
|
||||
@@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
|
||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||
msg = &route.RouteMessage{
|
||||
Type: action,
|
||||
Flags: unix.RTF_UP,
|
||||
Flags: unix.RTF_UP | routeProtoFlag,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: r.getSeq(),
|
||||
}
|
||||
|
||||
@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
|
||||
data, err := os.ReadFile(m.filePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
log.Debug("state file does not exist")
|
||||
log.Debugf("state file %s does not exist", m.filePath)
|
||||
return nil, nil // nolint:nilnil
|
||||
}
|
||||
return nil, fmt.Errorf("read state file: %w", err)
|
||||
|
||||
59
client/internal/winregistry/volatile_windows.go
Normal file
59
client/internal/winregistry/volatile_windows.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package winregistry
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
advapi = syscall.NewLazyDLL("advapi32.dll")
|
||||
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
|
||||
)
|
||||
|
||||
const (
|
||||
// Registry key options
|
||||
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
|
||||
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
|
||||
|
||||
// Registry disposition values
|
||||
regCreatedNewKey = 0x1
|
||||
regOpenedExistingKey = 0x2
|
||||
)
|
||||
|
||||
// CreateVolatileKey creates a volatile registry key named path under open key root.
|
||||
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
|
||||
// The access parameter specifies the access rights for the key to be created.
|
||||
//
|
||||
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
|
||||
// This provides automatic cleanup without requiring manual registry maintenance.
|
||||
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
|
||||
pathPtr, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
var (
|
||||
handle syscall.Handle
|
||||
disposition uint32
|
||||
)
|
||||
|
||||
ret, _, _ := regCreateKeyExW.Call(
|
||||
uintptr(root),
|
||||
uintptr(unsafe.Pointer(pathPtr)),
|
||||
0, // reserved
|
||||
0, // class
|
||||
uintptr(regOptionVolatile), // options - volatile key
|
||||
uintptr(access), // desired access
|
||||
0, // security attributes
|
||||
uintptr(unsafe.Pointer(&handle)),
|
||||
uintptr(unsafe.Pointer(&disposition)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return 0, false, syscall.Errno(ret)
|
||||
}
|
||||
|
||||
return registry.Key(handle), disposition == regOpenedExistingKey, nil
|
||||
}
|
||||
@@ -1057,10 +1057,7 @@ func (s *Server) Status(
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
if msg.GetFullPeerStatus {
|
||||
if msg.ShouldRunProbes {
|
||||
s.runProbes()
|
||||
}
|
||||
|
||||
s.runProbes(msg.ShouldRunProbes)
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
@@ -1070,7 +1067,7 @@ func (s *Server) Status(
|
||||
return &statusResponse, nil
|
||||
}
|
||||
|
||||
func (s *Server) runProbes() {
|
||||
func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
}
|
||||
@@ -1081,7 +1078,7 @@ func (s *Server) runProbes() {
|
||||
}
|
||||
|
||||
if time.Since(s.lastProbe) > probeThreshold {
|
||||
if engine.RunHealthProbes() {
|
||||
if engine.RunHealthProbes(waitForProbeResult) {
|
||||
s.lastProbe = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,9 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
||||
}
|
||||
|
||||
// clean up any remaining routes independently of the state file
|
||||
if !nbnet.AdvancedRouting() {
|
||||
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -340,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
for _, relay := range overview.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
|
||||
if !relay.Available {
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
if relay.Error == probeRelay.ErrCheckInProgress.Error() {
|
||||
available = "Checking..."
|
||||
} else {
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
}
|
||||
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -296,6 +296,8 @@ type serviceClient struct {
|
||||
mExitNodeDeselectAll *systray.MenuItem
|
||||
logFile string
|
||||
wLoginURL fyne.Window
|
||||
|
||||
connectCancel context.CancelFunc
|
||||
}
|
||||
|
||||
type menuHandler struct {
|
||||
@@ -592,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
||||
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get daemon client: %w", err)
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Errorf("get active profile: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get active profile: %w", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
@@ -610,84 +610,71 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
||||
return nil, fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
|
||||
loginResp, err := conn.Login(ctx, &proto.LoginRequest{
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
ProfileName: &activeProf.Name,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("login to management URL with: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("login to management: %w", err)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin && openURL {
|
||||
err = s.handleSSOLogin(loginResp, conn)
|
||||
if err != nil {
|
||||
log.Errorf("handle SSO login failed: %v", err)
|
||||
return nil, err
|
||||
if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil {
|
||||
return nil, fmt.Errorf("SSO login: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
||||
err := openURL(loginResp.VerificationURIComplete)
|
||||
if err != nil {
|
||||
log.Errorf("opening the verification uri in the browser failed: %v", err)
|
||||
return err
|
||||
func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
||||
if err := openURL(loginResp.VerificationURIComplete); err != nil {
|
||||
return fmt.Errorf("open browser: %w", err)
|
||||
}
|
||||
|
||||
resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
if err != nil {
|
||||
log.Errorf("waiting sso login failed with: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("wait for SSO login: %w", err)
|
||||
}
|
||||
|
||||
if resp.Email != "" {
|
||||
err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
|
||||
if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
|
||||
Email: resp.Email,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("failed to set profile state: %v", err)
|
||||
}); err != nil {
|
||||
log.Debugf("failed to set profile state: %v", err)
|
||||
} else {
|
||||
s.mProfile.refresh()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) menuUpClick() error {
|
||||
func (s *serviceClient) menuUpClick(ctx context.Context) error {
|
||||
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
systray.SetTemplateIcon(iconErrorMacOS, s.icError)
|
||||
log.Errorf("get client: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get daemon client: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.login(true)
|
||||
_, err = s.login(ctx, true)
|
||||
if err != nil {
|
||||
log.Errorf("login failed with: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
status, err := conn.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get status: %w", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
log.Warnf("already connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
return err
|
||||
if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("start connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -697,24 +684,20 @@ func (s *serviceClient) menuDownClick() error {
|
||||
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get daemon client: %w", err)
|
||||
}
|
||||
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get status: %w", err)
|
||||
}
|
||||
|
||||
if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) {
|
||||
log.Warnf("already down")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
log.Errorf("down service: %v", err)
|
||||
return err
|
||||
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("stop connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1381,7 +1364,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.login(false)
|
||||
resp, err := s.login(ctx, false)
|
||||
if err != nil {
|
||||
log.Errorf("failed to fetch login URL: %v", err)
|
||||
return
|
||||
@@ -1401,7 +1384,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
|
||||
_, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
|
||||
if err != nil {
|
||||
log.Errorf("Waiting sso login failed with: %v", err)
|
||||
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
|
||||
@@ -1409,7 +1392,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
}
|
||||
|
||||
label.SetText("Re-authentication successful.\nReconnecting")
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
status, err := conn.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return
|
||||
@@ -1422,7 +1405,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
_, err = conn.Up(ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
|
||||
log.Errorf("Reconnecting failed with: %v", err)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
|
||||
return "", err
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get post-up status: %v", err)
|
||||
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
|
||||
|
||||
var postUpStatusOutput string
|
||||
if postUpStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
|
||||
|
||||
var preDownStatusOutput string
|
||||
if preDownStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
return nil, fmt.Errorf("get client: %v", err)
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("failed to get status for debug bundle: %v", err)
|
||||
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
|
||||
var statusOutput string
|
||||
if statusResp != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/systray"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
|
||||
func (h *eventHandler) handleConnectClick() {
|
||||
h.client.mUp.Disable()
|
||||
|
||||
if h.client.connectCancel != nil {
|
||||
h.client.connectCancel()
|
||||
}
|
||||
|
||||
connectCtx, connectCancel := context.WithCancel(h.client.ctx)
|
||||
h.client.connectCancel = connectCancel
|
||||
|
||||
go func() {
|
||||
defer h.client.mUp.Enable()
|
||||
if err := h.client.menuUpClick(); err != nil {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
||||
defer connectCancel()
|
||||
|
||||
if err := h.client.menuUpClick(connectCtx); err != nil {
|
||||
st, ok := status.FromError(err)
|
||||
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
||||
log.Debugf("connect operation cancelled by user")
|
||||
} else {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
|
||||
log.Errorf("connect failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.client.updateStatus(); err != nil {
|
||||
log.Debugf("failed to update status after connect: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleDisconnectClick() {
|
||||
h.client.mDown.Disable()
|
||||
|
||||
if h.client.connectCancel != nil {
|
||||
log.Debugf("cancelling ongoing connect operation")
|
||||
h.client.connectCancel()
|
||||
h.client.connectCancel = nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer h.client.mDown.Enable()
|
||||
if err := h.client.menuDownClick(); err != nil {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon"))
|
||||
st, ok := status.FromError(err)
|
||||
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
|
||||
log.Errorf("disconnect failed: %v", err)
|
||||
} else {
|
||||
log.Debugf("disconnect cancelled or already disconnecting")
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.client.updateStatus(); err != nil {
|
||||
log.Debugf("failed to update status after disconnect: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
|
||||
}
|
||||
|
||||
h.client.getSrvConfig()
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -387,6 +387,7 @@ type subItem struct {
|
||||
type profileMenu struct {
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
serviceClient *serviceClient
|
||||
profileManager *profilemanager.ProfileManager
|
||||
eventHandler *eventHandler
|
||||
profileMenuItem *systray.MenuItem
|
||||
@@ -396,7 +397,7 @@ type profileMenu struct {
|
||||
logoutSubItem *subItem
|
||||
profilesState []Profile
|
||||
downClickCallback func() error
|
||||
upClickCallback func() error
|
||||
upClickCallback func(context.Context) error
|
||||
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
||||
loadSettingsCallback func()
|
||||
app fyne.App
|
||||
@@ -404,12 +405,13 @@ type profileMenu struct {
|
||||
|
||||
type newProfileMenuArgs struct {
|
||||
ctx context.Context
|
||||
serviceClient *serviceClient
|
||||
profileManager *profilemanager.ProfileManager
|
||||
eventHandler *eventHandler
|
||||
profileMenuItem *systray.MenuItem
|
||||
emailMenuItem *systray.MenuItem
|
||||
downClickCallback func() error
|
||||
upClickCallback func() error
|
||||
upClickCallback func(context.Context) error
|
||||
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
||||
loadSettingsCallback func()
|
||||
app fyne.App
|
||||
@@ -418,6 +420,7 @@ type newProfileMenuArgs struct {
|
||||
func newProfileMenu(args newProfileMenuArgs) *profileMenu {
|
||||
p := profileMenu{
|
||||
ctx: args.ctx,
|
||||
serviceClient: args.serviceClient,
|
||||
profileManager: args.profileManager,
|
||||
eventHandler: args.eventHandler,
|
||||
profileMenuItem: args.profileMenuItem,
|
||||
@@ -569,10 +572,19 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.upClickCallback(); err != nil {
|
||||
if p.serviceClient.connectCancel != nil {
|
||||
p.serviceClient.connectCancel()
|
||||
}
|
||||
|
||||
connectCtx, connectCancel := context.WithCancel(p.ctx)
|
||||
p.serviceClient.connectCancel = connectCancel
|
||||
|
||||
if err := p.upClickCallback(connectCtx); err != nil {
|
||||
log.Errorf("failed to handle up click after switching profile: %v", err)
|
||||
}
|
||||
|
||||
connectCancel()
|
||||
|
||||
p.refresh()
|
||||
p.loadSettingsCallback()
|
||||
}
|
||||
|
||||
@@ -19,6 +19,10 @@ const (
|
||||
RootZone = "."
|
||||
// DefaultClass is the class supported by the system
|
||||
DefaultClass = "IN"
|
||||
// ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort.
|
||||
ForwarderClientPort uint16 = 5353
|
||||
// ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here.
|
||||
ForwarderServerPort uint16 = 22054
|
||||
)
|
||||
|
||||
const invalidHostLabel = "[^a-zA-Z0-9-]+"
|
||||
@@ -31,6 +35,8 @@ type Config struct {
|
||||
NameServerGroups []*NameServerGroup
|
||||
// CustomZones contains a list of custom zone
|
||||
CustomZones []CustomZone
|
||||
// ForwarderPort is the port clients should connect to on routing peers for DNS forwarding
|
||||
ForwarderPort uint16
|
||||
}
|
||||
|
||||
// CustomZone represents a custom zone to be resolved by the dns server
|
||||
|
||||
4
go.mod
4
go.mod
@@ -56,13 +56,14 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
@@ -183,7 +184,6 @@ require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
|
||||
@@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
|
||||
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
||||
echo ""
|
||||
|
||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
||||
unset NETBIRD_LETSENCRYPT_DOMAIN
|
||||
unset NETBIRD_MGMT_API_CERT_FILE
|
||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||
fi
|
||||
|
||||
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
|
||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
||||
fi
|
||||
|
||||
# Check if management identity provider is set
|
||||
if [ -n "$NETBIRD_MGMT_IDP" ]; then
|
||||
EXTRA_CONFIG={}
|
||||
|
||||
@@ -40,13 +40,21 @@ services:
|
||||
signal:
|
||||
<<: *default
|
||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||
depends_on:
|
||||
- dashboard
|
||||
volumes:
|
||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
||||
ports:
|
||||
- $NETBIRD_SIGNAL_PORT:80
|
||||
# # port and command for Let's Encrypt validation
|
||||
# - 443:443
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
command: [
|
||||
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
||||
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||
"--log-file", "console"
|
||||
]
|
||||
|
||||
# Relay
|
||||
relay:
|
||||
|
||||
@@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
||||
|
||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||
return Create(s, func() permissions.Manager {
|
||||
return integrations.InitPermissionsManager(s.Store())
|
||||
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
manager.SetAccountManager(s.AccountManager())
|
||||
})
|
||||
|
||||
return manager
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -53,6 +53,9 @@ const (
|
||||
peerSchedulerRetryInterval = 3 * time.Second
|
||||
emptyUserID = "empty user ID in claims"
|
||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||
|
||||
envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||
envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||
)
|
||||
|
||||
type userLoggedInOnce bool
|
||||
@@ -109,6 +112,11 @@ type DefaultAccountManager struct {
|
||||
loginFilter *loginFilter
|
||||
|
||||
disableDefaultPolicy bool
|
||||
|
||||
holder *types.Holder
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
}
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
@@ -196,6 +204,18 @@ func BuildManager(
|
||||
log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start))
|
||||
}()
|
||||
|
||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err)
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
expIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
geo: geo,
|
||||
@@ -217,6 +237,10 @@ func BuildManager(
|
||||
permissionsManager: permissionsManager,
|
||||
loginFilter: newLoginFilter(),
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
holder: types.NewHolder(),
|
||||
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
@@ -395,6 +419,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
|
||||
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -1477,6 +1504,10 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
}
|
||||
@@ -2129,6 +2160,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateNetworkMap {
|
||||
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -109,7 +109,7 @@ type Manager interface {
|
||||
GetIdpManager() idp.Manager
|
||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
@@ -128,4 +128,5 @@ type Manager interface {
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
SetEphemeralManager(em ephemeral.Manager)
|
||||
AllowSync(string, uint64) bool
|
||||
RecalculateNetworkMapCache(ctx context.Context, accountId string) error
|
||||
}
|
||||
|
||||
@@ -1154,7 +1154,16 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
@@ -1181,7 +1190,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updChan := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
@@ -1189,7 +1198,12 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
message, _, _, err := updChan.NetworkMap.Pop(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("timeout waiting for network update")
|
||||
}
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
@@ -1205,18 +1219,26 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
// Ensure that we do not receive an update message before the policy is deleted
|
||||
time.Sleep(time.Second)
|
||||
select {
|
||||
case <-updMsg:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, _, _, err := updMsg.NetworkMap.Pop(ctx)
|
||||
if err != nil {
|
||||
t.Logf("received addPeer update message before policy deletion")
|
||||
default:
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
@@ -1224,7 +1246,12 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
message, _, _, err := updMsg.NetworkMap.Pop(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("timeout waiting for network update")
|
||||
}
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
@@ -1239,7 +1266,16 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
@@ -1261,7 +1297,12 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
message, _, _, err := updMsg.NetworkMap.Pop(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("timeout waiting for network update")
|
||||
}
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
@@ -1288,7 +1329,16 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
@@ -1326,7 +1376,12 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
message, _, _, err := updMsg.NetworkMap.Pop(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("timeout waiting for network update")
|
||||
}
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 1 {
|
||||
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
||||
@@ -1341,7 +1396,16 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -1382,7 +1446,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
message, _, _, err := updMsg.NetworkMap.Pop(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("timeout waiting for network update")
|
||||
}
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
@@ -1736,6 +1805,7 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Address: "172.12.6.1/24",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
err := hasNilField(account)
|
||||
if err != nil {
|
||||
@@ -2893,7 +2963,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(metrics), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2974,27 +3044,33 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account,
|
||||
return manager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessageBuffer *UpdateBuffer) {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
msg, _, _, _ := updateMessageBuffer.Pop(ctx)
|
||||
if msg != nil {
|
||||
t.Errorf("Unexpected message received: %+v", msg)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||
func peerShouldReceiveUpdate(t *testing.T, updateMessageBuffer *UpdateBuffer) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Timed out waiting for update message")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
msg, _, _, err := updateMessageBuffer.Pop(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Expected update message, but none received")
|
||||
}
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
@@ -3031,11 +3107,13 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
peerChannels := make(map[string]*UpdateChannel)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
peerChannels[peerID] = &UpdateChannel{
|
||||
Important: make(chan *UpdateMessage, channelBufferSize),
|
||||
NetworkMap: NewUpdateBuffer(manager.metrics.UpdateChannelMetrics()),
|
||||
}
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
@@ -3094,9 +3172,12 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
peerChannels := make(map[string]*UpdateChannel)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
peerChannels[peerID] = &UpdateChannel{
|
||||
Important: make(chan *UpdateMessage, channelBufferSize),
|
||||
NetworkMap: NewUpdateBuffer(manager.metrics.UpdateChannelMetrics()),
|
||||
}
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
@@ -3164,9 +3245,12 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
peerChannels := make(map[string]*UpdateChannel)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
peerChannels[peerID] = &UpdateChannel{
|
||||
Important: make(chan *UpdateMessage, channelBufferSize),
|
||||
NetworkMap: NewUpdateBuffer(manager.metrics.UpdateChannelMetrics()),
|
||||
}
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
||||
return fmt.Sprintf("%s/%s", audience, name)
|
||||
}
|
||||
|
||||
lastLogin := time.Date(2025, 2, 12, 14, 25, 26, 0, time.UTC) //"2025-02-12T14:25:26.186Z"
|
||||
lastLogin := time.Date(2025, 2, 12, 14, 25, 26, 0, time.UTC) // "2025-02-12T14:25:26.186Z"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -21,8 +21,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dnsForwarderPort = 22054
|
||||
oldForwarderPort = 5353
|
||||
dnsForwarderPort = nbdns.ForwarderServerPort
|
||||
oldForwarderPort = nbdns.ForwarderClientPort
|
||||
)
|
||||
|
||||
const dnsForwarderPortMinVersion = "v0.59.0"
|
||||
@@ -117,6 +117,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -196,7 +199,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
|
||||
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
|
||||
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
if len(peers) == 0 {
|
||||
return oldForwarderPort
|
||||
return int64(oldForwarderPort)
|
||||
}
|
||||
|
||||
reqVer := semver.Canonical(requiredVersion)
|
||||
@@ -211,17 +214,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
||||
if peerVersion == "" {
|
||||
// If any peer doesn't have version info, return 0
|
||||
return oldForwarderPort
|
||||
return int64(oldForwarderPort)
|
||||
}
|
||||
|
||||
// Compare versions
|
||||
if semver.Compare(peerVersion, reqVer) < 0 {
|
||||
return oldForwarderPort
|
||||
return int64(oldForwarderPort)
|
||||
}
|
||||
}
|
||||
|
||||
// All peers have the required version or newer
|
||||
return dnsForwarderPort
|
||||
return int64(dnsForwarderPort)
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
|
||||
@@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, dnsForwarderPort)
|
||||
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, dnsForwarderPort)
|
||||
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort)
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
@@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) {
|
||||
// Test with empty peers list
|
||||
peers := []*nbpeer.Peer{}
|
||||
result := computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
@@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) {
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
@@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) {
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != dnsForwarderPort {
|
||||
if result != int64(dnsForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
|
||||
}
|
||||
|
||||
@@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) {
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
@@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) {
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
@@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) {
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result == oldForwarderPort {
|
||||
if result == int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
|
||||
}
|
||||
|
||||
@@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) {
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
|
||||
}
|
||||
}
|
||||
@@ -609,7 +609,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -629,7 +629,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -662,7 +662,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -688,7 +688,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -708,7 +708,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -728,7 +728,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing group with peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -114,6 +114,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -182,6 +185,9 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -250,6 +256,9 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -318,6 +327,9 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -471,6 +483,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -509,6 +524,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -537,6 +555,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -575,6 +596,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -451,7 +451,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -474,7 +474,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -493,7 +493,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing peer from unliked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -511,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -544,7 +544,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -566,7 +566,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -584,7 +584,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing peer from linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -613,7 +613,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -654,7 +654,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -681,7 +681,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -728,7 +728,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@@ -44,6 +46,9 @@ import (
|
||||
const (
|
||||
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
|
||||
envBlockPeers = "NB_BLOCK_SAME_PEERS"
|
||||
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
|
||||
|
||||
defaultSyncLim = 1000
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
@@ -63,6 +68,10 @@ type GRPCServer struct {
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
debounce int
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -86,7 +95,7 @@ func NewServer(
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||
return int64(len(peersUpdateManager.peerChannels))
|
||||
return int64(peersUpdateManager.GetChannelCount())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -96,6 +105,24 @@ func NewServer(
|
||||
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
debounce := 1
|
||||
if decounceStr := os.Getenv("NB_UPDATE_BUFFER_DEBOUNCE_S"); decounceStr != "" {
|
||||
debounce, err = strconv.Atoi(decounceStr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid value for NB_UPDATE_BUFFER_DEBOUNCE_S: %v using 1 second", err)
|
||||
}
|
||||
}
|
||||
|
||||
syncLim := int32(defaultSyncLim)
|
||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
|
||||
} else {
|
||||
syncLim = int32(syncLimParsed)
|
||||
}
|
||||
}
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
@@ -110,6 +137,9 @@ func NewServer(
|
||||
logBlockedPeers: logBlockedPeers,
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
debounce: debounce,
|
||||
|
||||
syncLim: syncLim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -151,6 +181,11 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
if s.syncSem.Load() >= s.syncLim {
|
||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||
}
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
@@ -158,6 +193,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
syncReq := &proto.SyncRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, syncReq)
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
@@ -172,6 +208,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
|
||||
}
|
||||
}
|
||||
@@ -196,8 +233,10 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
|
||||
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
s.syncSem.Add(-1)
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -213,12 +252,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -237,18 +278,62 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates *UpdateChannel, srv proto.ManagementService_SyncServer) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
|
||||
// Channel to receive network map updates with metadata
|
||||
type networkMapUpdate struct {
|
||||
msg *UpdateMessage
|
||||
overwrites int
|
||||
timeSinceLastPop time.Duration
|
||||
}
|
||||
networkMapCh := make(chan networkMapUpdate)
|
||||
|
||||
// Start goroutine to Pop from buffer
|
||||
go func() {
|
||||
for {
|
||||
start := time.Now()
|
||||
update, overwrites, timeSinceLastPop, err := updates.NetworkMap.Pop(ctx)
|
||||
log.WithContext(ctx).Debugf("popped an update for peer %s from the network map buffer in %v (overwrites: %d)", peerKey.String(), time.Since(start), overwrites)
|
||||
if err != nil {
|
||||
close(networkMapCh)
|
||||
return
|
||||
}
|
||||
// start := time.Now()
|
||||
select {
|
||||
case networkMapCh <- networkMapUpdate{
|
||||
msg: update,
|
||||
overwrites: overwrites,
|
||||
timeSinceLastPop: timeSinceLastPop,
|
||||
}:
|
||||
// log.WithContext(ctx).Debugf("forwarded an update for peer %s from the network map buffer in %v (overwrites: %d)", peerKey.String(), time.Since(start), overwrites)
|
||||
|
||||
// Adaptive debounce: increase delay based on overwrite rate
|
||||
// If many overwrites happened, wait longer to let things settle
|
||||
debounce := s.calculateDebounce(overwrites, timeSinceLastPop)
|
||||
if debounce > 0 {
|
||||
// start = time.Now()
|
||||
time.Sleep(debounce)
|
||||
// log.WithContext(ctx).Debugf("debounced for %v for peer %s (overwrites: %d)", time.Since(start), peerKey.String(), overwrites)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
case update, open := <-updates:
|
||||
case update, open := <-updates.Important:
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates.Important) + 1)
|
||||
}
|
||||
|
||||
if !open {
|
||||
@@ -256,13 +341,24 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return nil
|
||||
}
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
// log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
case updateData, ok := <-networkMapCh:
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("update buffer for peer %s closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return nil
|
||||
}
|
||||
// log.WithContext(ctx).Debugf("sending latest update to peer %s (overwrites: %d)", peerKey.String(), updateData.overwrites)
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, updateData.msg, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// condition when client <-> server connection has been terminated
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
@@ -273,6 +369,33 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
|
||||
}
|
||||
}
|
||||
|
||||
// calculateDebounce calculates adaptive debounce duration based on overwrite rate
|
||||
// More overwrites = longer debounce to let updates settle
|
||||
func (s *GRPCServer) calculateDebounce(overwrites int, timeSinceLastPop time.Duration) time.Duration {
|
||||
if overwrites == 0 {
|
||||
// No overwrites, use base debounce
|
||||
return time.Duration(s.debounce) * time.Second
|
||||
}
|
||||
|
||||
var rate float64
|
||||
if timeSinceLastPop > 0 {
|
||||
rate = float64(overwrites) / timeSinceLastPop.Seconds()
|
||||
} else {
|
||||
rate = float64(overwrites)
|
||||
}
|
||||
|
||||
multiplier := 2.0
|
||||
adaptiveDebounce := float64(s.debounce) + (rate * multiplier)
|
||||
|
||||
// Cap at 10x base debounce
|
||||
maxDebounce := float64(s.debounce) * 10
|
||||
if adaptiveDebounce > maxDebounce {
|
||||
adaptiveDebounce = maxDebounce
|
||||
}
|
||||
|
||||
return time.Duration(adaptiveDebounce * float64(time.Second))
|
||||
}
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
|
||||
39
management/server/holder.go
Normal file
39
management/server/holder.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) {
|
||||
a := am.holder.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
am.holder.AddAccount(account)
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache = a.NetworkMapCache
|
||||
if account.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache.UpdateAccountPointer(account)
|
||||
am.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account {
|
||||
return am.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account {
|
||||
a := am.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) {
|
||||
am.holder.AddAccount(account)
|
||||
}
|
||||
@@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
||||
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
@@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
||||
}
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid))
|
||||
reason := invalidPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
@@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
reason := invalidPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid))
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
@@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
|
||||
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
h.setApprovalRequiredFlag(respBody, validPeersMap)
|
||||
h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
||||
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) {
|
||||
for _, peer := range respBody {
|
||||
_, ok := approvedPeersMap[peer.Id]
|
||||
_, ok := validPeersMap[peer.Id]
|
||||
if !ok {
|
||||
peer.ApprovalRequired = true
|
||||
|
||||
reason := invalidPeersMap[peer.Id]
|
||||
peer.DisapprovalReason = &reason
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
||||
}
|
||||
}
|
||||
|
||||
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
|
||||
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer {
|
||||
osVersion := peer.Meta.OSVersion
|
||||
if osVersion == "" {
|
||||
osVersion = peer.Meta.Core
|
||||
}
|
||||
|
||||
return &api.Peer{
|
||||
apiPeer := &api.Peer{
|
||||
CreatedAt: peer.CreatedAt,
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
@@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
}
|
||||
|
||||
if !approved {
|
||||
apiPeer.DisapprovalReason = &reason
|
||||
}
|
||||
|
||||
return apiPeer
|
||||
}
|
||||
|
||||
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
|
||||
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
|
||||
@@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
|
||||
var err error
|
||||
var groups []*types.Group
|
||||
var peers []*nbpeer.Peer
|
||||
@@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
||||
|
||||
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
|
||||
validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return validPeers, invalidPeers, nil
|
||||
}
|
||||
|
||||
type MockIntegratedValidator struct {
|
||||
@@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ type IntegratedValidator interface {
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
||||
GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error)
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
|
||||
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string))
|
||||
Stop(ctx context.Context)
|
||||
|
||||
@@ -125,9 +125,10 @@ type MockAccountManager struct {
|
||||
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
@@ -189,17 +190,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
|
||||
account, err := am.GetAccountFunc(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
approvedPeers := make(map[string]struct{})
|
||||
for id := range account.Peers {
|
||||
approvedPeers[id] = struct{}{}
|
||||
}
|
||||
return approvedPeers, nil
|
||||
return approvedPeers, nil, nil
|
||||
}
|
||||
|
||||
// GetGroup mock implementation of GetGroup from server.AccountManager interface
|
||||
@@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error {
|
||||
if am.RecalculateNetworkMapCacheFunc != nil {
|
||||
return am.RecalculateNetworkMapCacheFunc(ctx, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -83,6 +83,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -134,6 +137,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -177,6 +183,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -1004,7 +1004,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1031,7 +1031,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1049,7 +1049,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1075,7 +1075,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
80
management/server/networkmap.go
Normal file
80
management/server/networkmap.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
am.enrichAccountFromHolder(account)
|
||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
customZone nbdns.CustomZone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := am.getAccountFromHolderOrInit(accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
Network: &types.Network{},
|
||||
}
|
||||
}
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
am.enrichAccountFromHolder(account)
|
||||
return account.OnPeerAddedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
am.enrichAccountFromHolder(account)
|
||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||
account := am.getAccountFromHolder(accountId)
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
account.UpdatePeerInNetworkMapCache(peer)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
account.RecalculateNetworkMapCache(validatedPeers)
|
||||
am.updateAccountInHolder(account)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||
if am.experimentalNetworkMap(accountId) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||
return err
|
||||
}
|
||||
am.recalculateNetworkMapCache(account, validatedPeers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool {
|
||||
_, ok := am.expNewNetworkMapAIDs[accountId]
|
||||
return am.expNewNetworkMap || ok
|
||||
}
|
||||
@@ -177,6 +177,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -157,6 +157,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
|
||||
return resource, nil
|
||||
@@ -257,6 +260,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
|
||||
return resource, nil
|
||||
@@ -331,6 +337,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -119,6 +119,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
|
||||
return router, nil
|
||||
@@ -183,6 +186,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
|
||||
return router, nil
|
||||
@@ -217,6 +223,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
|
||||
event()
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -145,6 +146,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
|
||||
if expired {
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
@@ -321,6 +325,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
}
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
|
||||
if peerLabelChanged || requiresPeerUpdates {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
} else if sshChanged {
|
||||
@@ -381,6 +389,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if userID != activity.SystemInitiator {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
@@ -417,7 +437,13 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -690,6 +716,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
@@ -776,6 +813,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -901,6 +941,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -1014,9 +1057,17 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return peer, emptyMap, nil, nil
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account = am.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
@@ -1037,7 +1088,13 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -1167,11 +1224,18 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
return
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account = am.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
@@ -1196,7 +1260,6 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
dnsCache := &DNSConfigCache{}
|
||||
dnsDomain := am.GetDNSDomain(account.Settings)
|
||||
@@ -1204,6 +1267,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
@@ -1218,51 +1285,64 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
|
||||
|
||||
processNetworkMap := func(p *nbpeer.Peer) {
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
}
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
am.peersUpdateManager.SendNetworkMapUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
numWorkers := runtime.NumCPU()
|
||||
wch := make(chan *nbpeer.Peer, numWorkers)
|
||||
wg.Add(numWorkers)
|
||||
for range numWorkers {
|
||||
go func() {
|
||||
for peer := range wch {
|
||||
processNetworkMap(peer)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
}(peer)
|
||||
wch <- peer
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
close(wch)
|
||||
wg.Wait()
|
||||
if am.metrics != nil {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
@@ -1270,9 +1350,11 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
mu sync.Mutex
|
||||
next *time.Timer
|
||||
update atomic.Bool
|
||||
mu sync.Mutex
|
||||
next *time.Timer
|
||||
update atomic.Bool
|
||||
requestCount atomic.Int32
|
||||
lastResetTime atomic.Int64 // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
@@ -1283,6 +1365,7 @@ func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, a
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
b.requestCount.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1292,18 +1375,41 @@ func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, a
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
b.requestCount.Store(0)
|
||||
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
|
||||
requestsDuringProcessing := b.requestCount.Load()
|
||||
executionTime := time.Since(start)
|
||||
|
||||
requestsPerMinute := float64(requestsDuringProcessing) / executionTime.Seconds() * 60
|
||||
|
||||
adaptiveDelay := time.Duration(requestsPerMinute) * 250 * time.Millisecond
|
||||
|
||||
// cap the maximum delay to avoid too long delays
|
||||
maxDelay := 30 * time.Second
|
||||
if adaptiveDelay > maxDelay {
|
||||
adaptiveDelay = maxDelay
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("adaptive debounce for account %s: %d requests during %v execution, waiting %v",
|
||||
accountID, requestsDuringProcessing, executionTime, adaptiveDelay)
|
||||
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() {
|
||||
b.next = time.AfterFunc(adaptiveDelay, func() {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
b.next.Reset(adaptiveDelay)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1351,7 +1457,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
return
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -1368,7 +1480,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
|
||||
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
am.peersUpdateManager.SendNetworkMapUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
@@ -1565,7 +1677,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
|
||||
am.peersUpdateManager.SendImportantUpdate(ctx, peer.ID, &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
@@ -1580,7 +1692,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkMap: &types.NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
peerDeletedEvents = append(peerDeletedEvents, func() {
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
@@ -168,6 +167,15 @@ func TestPeer_SessionExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func testGetNetworkMapGeneral(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -945,14 +953,15 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 90, 120, 90, 120},
|
||||
{"Medium", 500, 100, 110, 150, 120, 260},
|
||||
{"Large", 5000, 200, 800, 1700, 2500, 5000},
|
||||
{"Small single", 50, 10, 90, 120, 90, 120},
|
||||
{"Medium single", 500, 10, 110, 170, 120, 200},
|
||||
{"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
|
||||
{"Extra Large", 2000, 2000, 1300, 2400, 3000, 6400},
|
||||
{"Small", 350, 5, 90, 120, 90, 120},
|
||||
// {"Medium", 500, 100, 110, 150, 120, 260},
|
||||
// {"Large", 5000, 2, 800, 1700, 2500, 5000},
|
||||
// {"Small single", 50, 10, 90, 120, 90, 120},
|
||||
// {"Medium single", 500, 10, 110, 170, 120, 200},
|
||||
// {"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
|
||||
// {"Extra Large", 2000, 2000, 1300, 2400, 3000, 6400},
|
||||
}
|
||||
b.Setenv("NB_EXPERIMENT_NETWORK_MAP", "false")
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
@@ -971,39 +980,49 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
peerChannels := make(map[string]*UpdateChannel)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
peerChannels[peerID] = &UpdateChannel{
|
||||
Important: make(chan *UpdateMessage, channelBufferSize),
|
||||
NetworkMap: NewUpdateBuffer(manager.metrics.UpdateChannelMetrics()),
|
||||
}
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < 1000; i++ {
|
||||
b.Logf("Run %d", i)
|
||||
manager.UpdateAccountPeers(ctx, account.Id)
|
||||
mm(b)
|
||||
manager.Store.IncrementNetworkSerial(ctx, account.Id)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.UpdateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer")
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mm(b *testing.B) {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
// runtime.GC()
|
||||
b.Logf("Alloc: %v MB, TotalAlloc: %v MB, Sys: %v MB, NumGC: %v", m.Alloc/1024/1024, m.TotalAlloc/1024/1024, m.Sys/1024/1024, m.NumGC)
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers(t *testing.T) {
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func testUpdateAccountPeers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
@@ -1026,25 +1045,33 @@ func TestUpdateAccountPeers(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create metrics: %v", err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
peerChannels := make(map[string]*UpdateChannel)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
peerChannels[peerID] = &UpdateChannel{
|
||||
Important: make(chan *UpdateMessage, channelBufferSize),
|
||||
NetworkMap: NewUpdateBuffer(metrics.UpdateChannelMetrics()),
|
||||
}
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
manager.UpdateAccountPeers(ctx, account.Id)
|
||||
|
||||
for _, channel := range peerChannels {
|
||||
update := <-channel
|
||||
update, _, _, _ := channel.NetworkMap.Pop(ctx)
|
||||
assert.Nil(t, update.Update.NetbirdConfig)
|
||||
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
|
||||
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
|
||||
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
|
||||
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1161,7 +1188,7 @@ func TestToSyncResponse(t *testing.T) {
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
|
||||
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort)
|
||||
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
|
||||
|
||||
assert.NotNil(t, response)
|
||||
// assert peer config
|
||||
@@ -1548,6 +1575,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LoginPeer(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
@@ -1772,7 +1800,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1790,7 +1818,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg) //
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap) //
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1815,7 +1843,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1833,7 +1861,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating peer label", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1856,7 +1884,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1878,7 +1906,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1911,7 +1939,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1937,7 +1965,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1975,7 +2003,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2001,7 +2029,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2030,7 +2058,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2056,7 +2084,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -22,6 +23,7 @@ type Manager interface {
|
||||
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
|
||||
|
||||
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
|
||||
SetAccountManager(accountManager account.Manager)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
account "github.com/netbirdio/netbird/management/server/account"
|
||||
modules "github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
@@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role)
|
||||
}
|
||||
|
||||
// SetAccountManager mocks base method.
|
||||
func (m *MockManager) SetAccountManager(accountManager account.Manager) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetAccountManager", accountManager)
|
||||
}
|
||||
|
||||
// SetAccountManager indicates an expected call of SetAccountManager.
|
||||
func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
|
||||
}
|
||||
|
||||
// ValidateAccountAccess mocks base method.
|
||||
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -77,6 +77,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -120,6 +123,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -1034,7 +1034,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1065,7 +1065,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1097,7 +1097,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1129,7 +1129,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1160,7 +1160,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1180,7 +1180,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1201,7 +1201,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1220,7 +1220,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1240,7 +1240,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1258,7 +1258,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -80,6 +80,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -198,7 +198,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -235,7 +235,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("linking posture check to policy with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -264,7 +264,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -282,7 +282,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing posture check from policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -301,7 +301,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -337,7 +337,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -381,7 +381,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg1.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -420,7 +420,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -192,6 +192,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -246,6 +249,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -289,6 +295,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -1998,7 +1998,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2034,7 +2034,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2070,7 +2070,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating route with a routing peer", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2095,7 +2095,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2113,7 +2113,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2149,7 +2149,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2189,7 +2189,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -431,7 +431,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -449,7 +449,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.NetworkMap)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1542
management/server/store/sqlstore_bench_test.go
Normal file
1542
management/server/store/sqlstore_bench_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -22,6 +22,9 @@ type UpdateChannelMetrics struct {
|
||||
calcPeerNetworkMapDurationMs metric.Int64Histogram
|
||||
mergeNetworkMapDurationMicro metric.Int64Histogram
|
||||
toSyncResponseDurationMicro metric.Int64Histogram
|
||||
bufferPushCounter metric.Int64Counter
|
||||
bufferOverwriteCounter metric.Int64Counter
|
||||
bufferIgnoreCounter metric.Int64Counter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -125,6 +128,27 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bufferPushCounter, err := meter.Int64Counter("management.updatechannel.buffer.push.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of updates pushed to an empty buffer"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bufferOverwriteCounter, err := meter.Int64Counter("management.updatechannel.buffer.overwrite.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of updates overwriting old unsent updates in the buffer"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bufferIgnoreCounter, err := meter.Int64Counter("management.updatechannel.buffer.ignore.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of updates being ignored due to old network serial"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UpdateChannelMetrics{
|
||||
createChannelDurationMicro: createChannelDurationMicro,
|
||||
closeChannelDurationMicro: closeChannelDurationMicro,
|
||||
@@ -138,6 +162,9 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
calcPeerNetworkMapDurationMs: calcPeerNetworkMapDurationMs,
|
||||
mergeNetworkMapDurationMicro: mergeNetworkMapDurationMicro,
|
||||
toSyncResponseDurationMicro: toSyncResponseDurationMicro,
|
||||
bufferPushCounter: bufferPushCounter,
|
||||
bufferOverwriteCounter: bufferOverwriteCounter,
|
||||
bufferIgnoreCounter: bufferIgnoreCounter,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
@@ -193,3 +220,18 @@ func (metrics *UpdateChannelMetrics) CountMergeNetworkMapDuration(duration time.
|
||||
func (metrics *UpdateChannelMetrics) CountToSyncResponseDuration(duration time.Duration) {
|
||||
metrics.toSyncResponseDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
// CountBufferPush counts how many buffer push operations are happening on an empty buffer
|
||||
func (metrics *UpdateChannelMetrics) CountBufferPush() {
|
||||
metrics.bufferPushCounter.Add(metrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountBufferOverwrite counts how many buffer overwrite operations are happening on a non-empty buffer
|
||||
func (metrics *UpdateChannelMetrics) CountBufferOverwrite() {
|
||||
metrics.bufferOverwriteCounter.Add(metrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountBufferIgnore counts how many buffer ignore operations are happening when a new update is pushed
|
||||
func (metrics *UpdateChannelMetrics) CountBufferIgnore() {
|
||||
metrics.bufferIgnoreCounter.Add(metrics.ctx, 1)
|
||||
}
|
||||
|
||||
@@ -227,7 +227,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
m.updateManager.SendImportantUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||
@@ -251,7 +251,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
m.updateManager.SendImportantUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||
|
||||
@@ -121,7 +121,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
loop:
|
||||
for timeout := time.After(5 * time.Second); ; {
|
||||
select {
|
||||
case update := <-updateChannel:
|
||||
case update := <-updateChannel.Important:
|
||||
updates = append(updates, update)
|
||||
case <-timeout:
|
||||
break loop
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -21,7 +22,6 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -87,6 +87,13 @@ type Account struct {
|
||||
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
|
||||
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
||||
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||
|
||||
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
|
||||
nmapInitOnce *sync.Once `gorm:"-"`
|
||||
}
|
||||
|
||||
func (a *Account) InitOnce() {
|
||||
a.nmapInitOnce = &sync.Once{}
|
||||
}
|
||||
|
||||
// this class is used by gorm only
|
||||
@@ -124,109 +131,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
||||
o.SignupFormPending == onboarding.SignupFormPending
|
||||
}
|
||||
|
||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
|
||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
peerRoutesMembership := make(LookupMap)
|
||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||
}
|
||||
|
||||
groupListMap := a.GetPeerGroups(peerID)
|
||||
for _, peer := range aclPeers {
|
||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
|
||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||
routes = append(routes, filteredRoutes...)
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
|
||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
_, found := peerMemberships[string(r.GetHAUniqueID())]
|
||||
if !found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
|
||||
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
for _, groupID := range r.Groups {
|
||||
_, found := groupListMap[groupID]
|
||||
if found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
// If the given is not a routing peer, then the lists are empty.
|
||||
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
|
||||
|
||||
peer := a.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
seenRoute := make(map[route.ID]struct{})
|
||||
|
||||
takeRoute := func(r *route.Route, id string) {
|
||||
if _, ok := seenRoute[r.ID]; ok {
|
||||
return
|
||||
}
|
||||
seenRoute[r.ID] = struct{}{}
|
||||
|
||||
if r.Enabled {
|
||||
r.Peer = peer.Key
|
||||
enabledRoutes = append(enabledRoutes, r)
|
||||
return
|
||||
}
|
||||
disabledRoutes = append(disabledRoutes, r)
|
||||
}
|
||||
|
||||
for _, r := range a.Routes {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
for _, id := range group.Peers {
|
||||
if id != peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
newPeerRoute := r.Copy()
|
||||
newPeerRoute.Peer = id
|
||||
newPeerRoute.PeerGroups = nil
|
||||
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
|
||||
takeRoute(newPeerRoute, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
if r.Peer == peerID {
|
||||
takeRoute(r.Copy(), peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
|
||||
var routes []*route.Route
|
||||
@@ -246,174 +150,6 @@ func (a *Account) GetGroup(groupID string) *Group {
|
||||
return a.Groups[groupID]
|
||||
}
|
||||
|
||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||
func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peerID]; !ok {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
for _, p := range aclPeers {
|
||||
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||
if a.Settings.PeerLoginExpirationEnabled && expired {
|
||||
expiredPeers = append(expiredPeers, p)
|
||||
continue
|
||||
}
|
||||
peersToConnect = append(peersToConnect, p)
|
||||
}
|
||||
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
|
||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||
if isRouter {
|
||||
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
|
||||
}
|
||||
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
|
||||
|
||||
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: dnsManagementStatus,
|
||||
}
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnectIncludingRouters,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
|
||||
DNSConfig: dnsUpdate,
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
if objectCount > 5000 {
|
||||
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
|
||||
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
|
||||
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
|
||||
}
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (a *Account) addNetworksRoutingPeers(
|
||||
networkResourcesRoutes []*route.Route,
|
||||
peer *nbpeer.Peer,
|
||||
peersToConnect []*nbpeer.Peer,
|
||||
expiredPeers []*nbpeer.Peer,
|
||||
isRouter bool,
|
||||
sourcePeers map[string]struct{},
|
||||
) []*nbpeer.Peer {
|
||||
|
||||
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
|
||||
for _, r := range networkResourcesRoutes {
|
||||
networkRoutesPeers[r.PeerID] = struct{}{}
|
||||
}
|
||||
|
||||
delete(sourcePeers, peer.ID)
|
||||
delete(networkRoutesPeers, peer.ID)
|
||||
|
||||
for _, existingPeer := range peersToConnect {
|
||||
delete(sourcePeers, existingPeer.ID)
|
||||
delete(networkRoutesPeers, existingPeer.ID)
|
||||
}
|
||||
for _, expPeer := range expiredPeers {
|
||||
delete(sourcePeers, expPeer.ID)
|
||||
delete(networkRoutesPeers, expPeer.ID)
|
||||
}
|
||||
|
||||
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
|
||||
if isRouter {
|
||||
for p := range sourcePeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
for p := range networkRoutesPeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
for p := range missingPeers {
|
||||
if missingPeer := a.Peers[p]; missingPeer != nil {
|
||||
peersToConnect = append(peersToConnect, missingPeer)
|
||||
}
|
||||
}
|
||||
|
||||
return peersToConnect
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
groupList := account.GetPeerGroups(peerID)
|
||||
|
||||
var peerNSGroups []*nbdns.NameServerGroup
|
||||
|
||||
for _, nsGroup := range account.NameServerGroups {
|
||||
if !nsGroup.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, gID := range nsGroup.Groups {
|
||||
_, found := groupList[gID]
|
||||
if found {
|
||||
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerNSGroups
|
||||
}
|
||||
|
||||
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
|
||||
for _, peer := range account.Peers {
|
||||
label, err := GetPeerHostLabel(peer.Name, peerLabels)
|
||||
@@ -760,32 +496,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
|
||||
return grps
|
||||
}
|
||||
|
||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
enabled := true
|
||||
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
|
||||
_, found := peerGroups[groupID]
|
||||
if found {
|
||||
enabled = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
||||
groupList := make(LookupMap)
|
||||
for groupID, group := range a.Groups {
|
||||
for _, id := range group.Peers {
|
||||
if id == peerID {
|
||||
groupList[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return groupList
|
||||
}
|
||||
|
||||
func (a *Account) GetTakenIPs() []net.IP {
|
||||
var takenIps []net.IP
|
||||
for _, existingPeer := range a.Peers {
|
||||
@@ -890,6 +600,7 @@ func (a *Account) Copy() *Account {
|
||||
NetworkRouters: networkRouters,
|
||||
NetworkResources: networkResources,
|
||||
Onboarding: a.Onboarding,
|
||||
NetworkMapCache: a.NetworkMapCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -985,192 +696,6 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// GetPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
var sourcePeers, destinationPeers []*nbpeer.Peer
|
||||
var peerInSources, peerInDestinations bool
|
||||
|
||||
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
sourcePeers, peerInSources = a.getPeerFromResource(rule.SourceResource, peer.ID)
|
||||
} else {
|
||||
sourcePeers, peerInSources = a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
}
|
||||
|
||||
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
destinationPeers, peerInDestinations = a.getPeerFromResource(rule.DestinationResource, peer.ID)
|
||||
} else {
|
||||
destinationPeers, peerInDestinations = a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
|
||||
}
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
|
||||
}
|
||||
}
|
||||
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
|
||||
}
|
||||
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return getAccumulatedResources()
|
||||
}
|
||||
|
||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||
//
|
||||
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
||||
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
||||
// generated. The accumulator function returns the result of all the generator calls.
|
||||
func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||
rulesExists := make(map[string]struct{})
|
||||
peersExists := make(map[string]struct{})
|
||||
rules := make([]*FirewallRule, 0)
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
|
||||
all, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
|
||||
all = &Group{}
|
||||
}
|
||||
|
||||
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
|
||||
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := peersExists[peer.ID]; !ok {
|
||||
peers = append(peers, peer)
|
||||
peersExists[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
fr := FirewallRule{
|
||||
PolicyID: rule.ID,
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
}
|
||||
|
||||
if isAll {
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
|
||||
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
|
||||
rules = append(rules, &fr)
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
|
||||
}
|
||||
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
return peers, rules
|
||||
}
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
//
|
||||
// Returns a list of peers from specified groups that pass specified posture checks
|
||||
// and a boolean indicating if the supplied peer ID exists within these groups.
|
||||
//
|
||||
// Important: Posture checks are applicable only to source group peers,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
||||
for _, p := range uniquePeerIDs {
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
filteredPeers = append(filteredPeers, peer)
|
||||
}
|
||||
|
||||
return filteredPeers, peerInGroups
|
||||
}
|
||||
|
||||
func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) {
|
||||
peer := a.GetPeer(resource.ID)
|
||||
if peer == nil {
|
||||
return []*nbpeer.Peer{}, false
|
||||
}
|
||||
|
||||
return []*nbpeer.Peer{peer}, resource.ID == peerID
|
||||
}
|
||||
|
||||
// validatePostureChecksOnPeer validates the posture checks on a peer
|
||||
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
|
||||
peer, ok := a.Peers[peerID]
|
||||
if !ok && peer == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, postureChecksID := range sourcePostureChecksID {
|
||||
postureChecks := a.GetPostureChecks(postureChecksID)
|
||||
if postureChecks == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, check := range postureChecks.GetChecks() {
|
||||
isValid, err := check.Check(ctx, *peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
|
||||
}
|
||||
if !isValid {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
|
||||
for _, postureChecks := range a.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
@@ -1180,174 +705,6 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
|
||||
|
||||
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
for _, route := range enabledRoutes {
|
||||
// If no access control groups are specified, accept all traffic.
|
||||
if len(route.AccessControlGroups) == 0 {
|
||||
defaultPermit := getDefaultPermit(route)
|
||||
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
||||
continue
|
||||
}
|
||||
|
||||
distributionPeers := a.getDistributionGroupsPeers(route)
|
||||
|
||||
for _, accessGroup := range route.AccessControlGroups {
|
||||
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
|
||||
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||
var fwRules []*RouteFirewallRule
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
|
||||
fwRules = append(fwRules, rules...)
|
||||
}
|
||||
}
|
||||
return fwRules
|
||||
}
|
||||
|
||||
func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
||||
distPeersWithPolicy := make(map[string]struct{})
|
||||
for _, id := range rule.Sources {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
if pID == peerID {
|
||||
continue
|
||||
}
|
||||
_, distPeer := distributionPeers[pID]
|
||||
_, valid := validatedPeersMap[pID]
|
||||
if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
|
||||
distPeersWithPolicy[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
|
||||
for pID := range distPeersWithPolicy {
|
||||
peer := a.Peers[pID]
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
distributionGroupPeers = append(distributionGroupPeers, peer)
|
||||
}
|
||||
return distributionGroupPeers
|
||||
}
|
||||
|
||||
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
|
||||
distPeers := make(map[string]struct{})
|
||||
for _, id := range route.Groups {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
distPeers[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return distPeers
|
||||
}
|
||||
|
||||
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
||||
var rules []*RouteFirewallRule
|
||||
|
||||
sources := []string{"0.0.0.0/0"}
|
||||
if route.Network.Addr().Is6() {
|
||||
sources = []string{"::/0"}
|
||||
}
|
||||
rule := RouteFirewallRule{
|
||||
SourceRanges: sources,
|
||||
Action: string(PolicyTrafficActionAccept),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(PolicyRuleProtocolALL),
|
||||
Domains: route.Domains,
|
||||
IsDynamic: route.IsDynamic(),
|
||||
RouteID: route.ID,
|
||||
}
|
||||
|
||||
rules = append(rules, &rule)
|
||||
|
||||
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
|
||||
if route.IsDynamic() {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
||||
// and returns a list of policies that have rules with destinations matching the specified groups.
|
||||
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
||||
routePolicies := make([]*Policy, 0)
|
||||
for _, groupID := range accessControlGroups {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool {
|
||||
return groupID == group.ID
|
||||
})
|
||||
if exist {
|
||||
routePolicies = append(routePolicies, policy)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routePolicies
|
||||
}
|
||||
|
||||
// GetPeerNetworkResourceFirewallRules gets the network resources firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, routes []*route.Route, resourcePolicies map[string][]*Policy) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0)
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Peer != peer.Key {
|
||||
continue
|
||||
}
|
||||
resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())]
|
||||
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
|
||||
|
||||
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
|
||||
for _, rule := range rules {
|
||||
if len(rule.SourceRanges) > 0 {
|
||||
routesFirewallRules = append(routesFirewallRules, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
// getNetworkResourceGroups retrieves all groups associated with the given network resource.
|
||||
func (a *Account) getNetworkResourceGroups(resourceID string) []*Group {
|
||||
var networkResourceGroups []*Group
|
||||
@@ -1377,91 +734,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
|
||||
return resourcePolicies
|
||||
}
|
||||
|
||||
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
|
||||
var isRoutingPeer bool
|
||||
var routes []*route.Route
|
||||
allSourcePeers := make(map[string]struct{}, len(a.Peers))
|
||||
|
||||
for _, resource := range a.NetworkResources {
|
||||
if !resource.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
var addSourcePeers bool
|
||||
|
||||
networkRoutingPeers, exists := routers[resource.NetworkID]
|
||||
if exists {
|
||||
if router, ok := networkRoutingPeers[peerID]; ok {
|
||||
isRoutingPeer, addSourcePeers = true, true
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
|
||||
}
|
||||
}
|
||||
|
||||
addedResourceRoute := false
|
||||
for _, policy := range resourcePolicies[resource.ID] {
|
||||
var peers []string
|
||||
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
|
||||
peers = []string{policy.Rules[0].SourceResource.ID}
|
||||
} else {
|
||||
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
||||
}
|
||||
if addSourcePeers {
|
||||
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
||||
allSourcePeers[pID] = struct{}{}
|
||||
}
|
||||
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
|
||||
// add routes for the resource if the peer is in the distribution group
|
||||
for peerId, router := range networkRoutingPeers {
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
||||
}
|
||||
addedResourceRoute = true
|
||||
}
|
||||
if addedResourceRoute {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isRoutingPeer, routes, allSourcePeers
|
||||
}
|
||||
|
||||
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
||||
var dest []string
|
||||
for _, peerID := range inputPeers {
|
||||
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
|
||||
dest = append(dest, peerID)
|
||||
}
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
||||
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
||||
for _, groupID := range groups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
|
||||
if group.IsGroupAll() || len(groups) == 1 {
|
||||
return group.Peers
|
||||
}
|
||||
|
||||
for _, peerID := range group.Peers {
|
||||
peerIDs[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(peerIDs))
|
||||
for peerID := range peerIDs {
|
||||
ids = append(ids, peerID)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// getNetworkResources filters and returns a list of network resources associated with the given network ID.
|
||||
func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource {
|
||||
var resources []*resourceTypes.NetworkResource
|
||||
@@ -1527,22 +799,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// getNetworkResourcesRoutes convert the network resources list to routes list.
|
||||
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
|
||||
resourceAppliedPolicies := resourcePolicies[resource.ID]
|
||||
|
||||
var routes []*route.Route
|
||||
// distribute the resource routes only if there is policy applied to it
|
||||
if len(resourceAppliedPolicies) > 0 {
|
||||
peer := a.GetPeer(peerId)
|
||||
if peer != nil {
|
||||
routes = append(routes, resource.ToRoute(peer, router))
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
|
||||
routers := make(map[string]map[string]*routerTypes.NetworkRouter)
|
||||
|
||||
@@ -1573,28 +829,6 @@ func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.Net
|
||||
return routers
|
||||
}
|
||||
|
||||
// getPoliciesSourcePeers collects all unique peers from the source groups defined in the given policies.
|
||||
func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[string]struct{} {
|
||||
sourcePeers := make(map[string]struct{})
|
||||
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group := groups[sourceGroup]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, peer := range group.Peers {
|
||||
sourcePeers[peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sourcePeers
|
||||
}
|
||||
|
||||
// AddAllGroup to account object if it doesn't exist
|
||||
func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
if len(a.Groups) == 0 {
|
||||
@@ -1638,66 +872,3 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
var expanded []*FirewallRule
|
||||
|
||||
if len(rule.Ports) > 0 {
|
||||
for _, port := range rule.Ports {
|
||||
fr := base
|
||||
fr.Port = port
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
return expanded
|
||||
}
|
||||
|
||||
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
|
||||
for _, portRange := range rule.PortRanges {
|
||||
fr := base
|
||||
|
||||
if supportPortRanges {
|
||||
fr.PortRange = portRange
|
||||
} else {
|
||||
// Peer doesn't support port ranges, only allow single-port ranges
|
||||
if portRange.Start != portRange.End {
|
||||
continue
|
||||
}
|
||||
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
|
||||
}
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
return expanded
|
||||
}
|
||||
|
||||
// peerSupportsPortRanges checks if the peer version supports port ranges.
|
||||
func peerSupportsPortRanges(peerVer string) bool {
|
||||
if strings.Contains(peerVer, "dev") {
|
||||
return true
|
||||
}
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
return err == nil && meetMinVer
|
||||
}
|
||||
|
||||
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
||||
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
|
||||
peerIPs := make(map[string]struct{})
|
||||
|
||||
// Add peer's own IP to include its own DNS records
|
||||
peerIPs[peer.IP.String()] = struct{}{}
|
||||
|
||||
for _, peerToConnect := range peersToConnect {
|
||||
peerIPs[peerToConnect.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
if _, exists := peerIPs[record.RData]; exists {
|
||||
filteredRecords = append(filteredRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredRecords
|
||||
}
|
||||
|
||||
@@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
peer *nbpeer.Peer
|
||||
customZone nbdns.CustomZone
|
||||
peersToConnect []*nbpeer.Peer
|
||||
expiredPeers []*nbpeer.Peer
|
||||
expectedRecords []nbdns.SimpleRecord
|
||||
}{
|
||||
{
|
||||
@@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
},
|
||||
},
|
||||
peersToConnect: []*nbpeer.Peer{},
|
||||
expiredPeers: []*nbpeer.Peer{},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: []nbdns.SimpleRecord{
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
@@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
}
|
||||
return peers
|
||||
}(),
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expiredPeers: []*nbpeer.Peer{},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: func() []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
|
||||
@@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
|
||||
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
|
||||
},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expiredPeers: []*nbpeer.Peer{},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
@@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired peers are included in DNS entries",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
peersToConnect: []*nbpeer.Peer{
|
||||
{ID: "peer1", IP: net.ParseIP("10.0.0.1")},
|
||||
},
|
||||
expiredPeers: []*nbpeer.Peer{
|
||||
{ID: "expired-peer", IP: net.ParseIP("10.0.0.99")},
|
||||
},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
|
||||
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers)
|
||||
assert.Equal(t, len(tt.expectedRecords), len(result))
|
||||
assert.ElementsMatch(t, tt.expectedRecords, result)
|
||||
})
|
||||
|
||||
43
management/server/types/holder.go
Normal file
43
management/server/types/holder.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Holder struct {
|
||||
mu sync.RWMutex
|
||||
accounts map[string]*Account
|
||||
}
|
||||
|
||||
func NewHolder() *Holder {
|
||||
return &Holder{
|
||||
accounts: make(map[string]*Account),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Holder) GetAccount(id string) *Account {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.accounts[id]
|
||||
}
|
||||
|
||||
func (h *Holder) AddAccount(account *Account) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.accounts[account.Id] = account
|
||||
}
|
||||
|
||||
func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if acc, ok := h.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
account, err := accGetter(context.Background(), id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.accounts[id] = account
|
||||
return account, nil
|
||||
}
|
||||
@@ -9,12 +9,7 @@ import (
|
||||
|
||||
"github.com/c-robinson/iplib"
|
||||
"github.com/rs/xid"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -29,40 +24,6 @@ const (
|
||||
AllowedIPsFormat = "%s/32"
|
||||
)
|
||||
|
||||
type NetworkMap struct {
|
||||
Peers []*nbpeer.Peer
|
||||
Network *Network
|
||||
Routes []*route.Route
|
||||
DNSConfig nbdns.Config
|
||||
OfflinePeers []*nbpeer.Peer
|
||||
FirewallRules []*FirewallRule
|
||||
RoutesFirewallRules []*RouteFirewallRule
|
||||
ForwardingRules []*ForwardingRule
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
||||
nm.Peers = mergeUniquePeersByID(nm.Peers, other.Peers)
|
||||
nm.Routes = util.MergeUnique(nm.Routes, other.Routes)
|
||||
nm.OfflinePeers = mergeUniquePeersByID(nm.OfflinePeers, other.OfflinePeers)
|
||||
nm.FirewallRules = util.MergeUnique(nm.FirewallRules, other.FirewallRules)
|
||||
nm.RoutesFirewallRules = util.MergeUnique(nm.RoutesFirewallRules, other.RoutesFirewallRules)
|
||||
nm.ForwardingRules = util.MergeUnique(nm.ForwardingRules, other.ForwardingRules)
|
||||
}
|
||||
|
||||
func mergeUniquePeersByID(peers1, peers2 []*nbpeer.Peer) []*nbpeer.Peer {
|
||||
result := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers1 {
|
||||
result[peer.ID] = peer
|
||||
}
|
||||
for _, peer := range peers2 {
|
||||
if _, ok := result[peer.ID]; !ok {
|
||||
result[peer.ID] = peer
|
||||
}
|
||||
}
|
||||
|
||||
return maps.Values(result)
|
||||
}
|
||||
|
||||
type ForwardingRule struct {
|
||||
RuleProtocol string
|
||||
DestinationPorts RulePortRange
|
||||
|
||||
947
management/server/types/networkmap.go
Normal file
947
management/server/types/networkmap.go
Normal file
@@ -0,0 +1,947 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type NetworkMap struct {
|
||||
Peers []*nbpeer.Peer
|
||||
Network *Network
|
||||
Routes []*route.Route
|
||||
DNSConfig nbdns.Config
|
||||
OfflinePeers []*nbpeer.Peer
|
||||
FirewallRules []*FirewallRule
|
||||
RoutesFirewallRules []*RouteFirewallRule
|
||||
ForwardingRules []*ForwardingRule
|
||||
}
|
||||
|
||||
// var nmapPool = sync.Pool{
|
||||
// New: func() any {
|
||||
// return &NetworkMap{}
|
||||
// },
|
||||
// }
|
||||
|
||||
// func GetNetworkMap() *NetworkMap {
|
||||
// n := nmapPool.Get().(*NetworkMap)
|
||||
// n.reset()
|
||||
// return n
|
||||
// }
|
||||
|
||||
// func PutNetworkMap(n *NetworkMap) {
|
||||
// nmapPool.Put(n)
|
||||
// }
|
||||
|
||||
// func (n *NetworkMap) reset() {
|
||||
// n.Peers = n.Peers[:0]
|
||||
// n.Network = nil
|
||||
// n.Routes = n.Routes[:0]
|
||||
// n.DNSConfig = nbdns.Config{}
|
||||
// n.OfflinePeers = n.OfflinePeers[:0]
|
||||
// n.FirewallRules = n.FirewallRules[:0]
|
||||
// n.RoutesFirewallRules = n.RoutesFirewallRules[:0]
|
||||
// n.ForwardingRules = n.ForwardingRules[:0]
|
||||
// }
|
||||
|
||||
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
||||
nm.Peers = mergeUniquePeersByID(nm.Peers, other.Peers)
|
||||
nm.Routes = util.MergeUnique(nm.Routes, other.Routes)
|
||||
nm.OfflinePeers = mergeUniquePeersByID(nm.OfflinePeers, other.OfflinePeers)
|
||||
nm.FirewallRules = util.MergeUnique(nm.FirewallRules, other.FirewallRules)
|
||||
nm.RoutesFirewallRules = util.MergeUnique(nm.RoutesFirewallRules, other.RoutesFirewallRules)
|
||||
nm.ForwardingRules = util.MergeUnique(nm.ForwardingRules, other.ForwardingRules)
|
||||
}
|
||||
|
||||
// TODO optimize
|
||||
func mergeUniquePeersByID(peers1, peers2 []*nbpeer.Peer) []*nbpeer.Peer {
|
||||
result := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers1 {
|
||||
result[peer.ID] = peer
|
||||
}
|
||||
for _, peer := range peers2 {
|
||||
if _, ok := result[peer.ID]; !ok {
|
||||
result[peer.ID] = peer
|
||||
}
|
||||
}
|
||||
|
||||
return maps.Values(result)
|
||||
}
|
||||
|
||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||
func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peerID]; !ok {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
for _, p := range aclPeers {
|
||||
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||
if a.Settings.PeerLoginExpirationEnabled && expired {
|
||||
expiredPeers = append(expiredPeers, p)
|
||||
continue
|
||||
}
|
||||
peersToConnect = append(peersToConnect, p)
|
||||
}
|
||||
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
|
||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||
if isRouter {
|
||||
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
|
||||
}
|
||||
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
|
||||
|
||||
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: dnsManagementStatus,
|
||||
}
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnectIncludingRouters,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
|
||||
DNSConfig: dnsUpdate,
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
if objectCount > 5000 {
|
||||
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
|
||||
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
|
||||
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
|
||||
}
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
// GetPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
|
||||
}
|
||||
}
|
||||
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
|
||||
}
|
||||
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return getAccumulatedResources()
|
||||
}
|
||||
|
||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||
//
|
||||
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
||||
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
||||
// generated. The accumulator function returns the result of all the generator calls.
|
||||
func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||
rulesExists := make(map[string]struct{})
|
||||
peersExists := make(map[string]struct{})
|
||||
rules := make([]*FirewallRule, 0)
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
|
||||
all, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
|
||||
all = &Group{}
|
||||
}
|
||||
|
||||
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
|
||||
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := peersExists[peer.ID]; !ok {
|
||||
peers = append(peers, peer)
|
||||
peersExists[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
fr := FirewallRule{
|
||||
PolicyID: rule.ID,
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
}
|
||||
|
||||
if isAll {
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
|
||||
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
|
||||
rules = append(rules, &fr)
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
|
||||
}
|
||||
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
return peers, rules
|
||||
}
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
//
|
||||
// Returns a list of peers from specified groups that pass specified posture checks
|
||||
// and a boolean indicating if the supplied peer ID exists within these groups.
|
||||
//
|
||||
// Important: Posture checks are applicable only to source group peers,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
||||
for _, p := range uniquePeerIDs {
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
filteredPeers = append(filteredPeers, peer)
|
||||
}
|
||||
|
||||
return filteredPeers, peerInGroups
|
||||
}
|
||||
|
||||
// validatePostureChecksOnPeer validates the posture checks on a peer
|
||||
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
|
||||
peer, ok := a.Peers[peerID]
|
||||
if !ok && peer == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, postureChecksID := range sourcePostureChecksID {
|
||||
postureChecks := a.GetPostureChecks(postureChecksID)
|
||||
if postureChecks == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, check := range postureChecks.GetChecks() {
|
||||
isValid, err := check.Check(ctx, *peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
|
||||
}
|
||||
if !isValid {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
expanded := make([]*FirewallRule, 0, len(rule.Ports)+len(rule.PortRanges))
|
||||
|
||||
if len(rule.Ports) > 0 {
|
||||
for _, port := range rule.Ports {
|
||||
fr := base
|
||||
fr.Port = port
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
return expanded
|
||||
}
|
||||
|
||||
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
|
||||
for _, portRange := range rule.PortRanges {
|
||||
fr := base
|
||||
|
||||
if supportPortRanges {
|
||||
fr.PortRange = portRange
|
||||
} else {
|
||||
// Peer doesn't support port ranges, only allow single-port ranges
|
||||
if portRange.Start != portRange.End {
|
||||
continue
|
||||
}
|
||||
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
|
||||
}
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
return expanded
|
||||
}
|
||||
|
||||
// peerSupportsPortRanges checks if the peer version supports port ranges.
|
||||
func peerSupportsPortRanges(peerVer string) bool {
|
||||
if strings.Contains(peerVer, "dev") {
|
||||
return true
|
||||
}
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
return err == nil && meetMinVer
|
||||
}
|
||||
|
||||
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
|
||||
var isRoutingPeer bool
|
||||
var routes []*route.Route
|
||||
allSourcePeers := make(map[string]struct{}, len(a.Peers))
|
||||
|
||||
for _, resource := range a.NetworkResources {
|
||||
if !resource.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
var addSourcePeers bool
|
||||
|
||||
networkRoutingPeers, exists := routers[resource.NetworkID]
|
||||
if exists {
|
||||
if router, ok := networkRoutingPeers[peerID]; ok {
|
||||
isRoutingPeer, addSourcePeers = true, true
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
|
||||
}
|
||||
}
|
||||
|
||||
addedResourceRoute := false
|
||||
for _, policy := range resourcePolicies[resource.ID] {
|
||||
peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
||||
if addSourcePeers {
|
||||
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
||||
allSourcePeers[pID] = struct{}{}
|
||||
}
|
||||
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
|
||||
// add routes for the resource if the peer is in the distribution group
|
||||
for peerId, router := range networkRoutingPeers {
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
||||
}
|
||||
addedResourceRoute = true
|
||||
}
|
||||
if addedResourceRoute {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isRoutingPeer, routes, allSourcePeers
|
||||
}
|
||||
|
||||
// getNetworkResourcesRoutes convert the network resources list to routes list.
|
||||
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
|
||||
resourceAppliedPolicies := resourcePolicies[resource.ID]
|
||||
|
||||
var routes []*route.Route
|
||||
// distribute the resource routes only if there is policy applied to it
|
||||
if len(resourceAppliedPolicies) > 0 {
|
||||
peer := a.GetPeer(peerId)
|
||||
if peer != nil {
|
||||
routes = append(routes, resource.ToRoute(peer, router))
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
||||
var dest []string
|
||||
for _, peerID := range inputPeers {
|
||||
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
|
||||
dest = append(dest, peerID)
|
||||
}
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
||||
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
||||
for _, groupID := range groups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
|
||||
if group.IsGroupAll() || len(groups) == 1 {
|
||||
return group.Peers
|
||||
}
|
||||
|
||||
for _, peerID := range group.Peers {
|
||||
peerIDs[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(peerIDs))
|
||||
for peerID := range peerIDs {
|
||||
ids = append(ids, peerID)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
|
||||
|
||||
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
for _, route := range enabledRoutes {
|
||||
// If no access control groups are specified, accept all traffic.
|
||||
if len(route.AccessControlGroups) == 0 {
|
||||
defaultPermit := getDefaultPermit(route)
|
||||
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
||||
continue
|
||||
}
|
||||
|
||||
distributionPeers := a.getDistributionGroupsPeers(route)
|
||||
|
||||
for _, accessGroup := range route.AccessControlGroups {
|
||||
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
|
||||
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||
var fwRules []*RouteFirewallRule
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
|
||||
fwRules = append(fwRules, rules...)
|
||||
}
|
||||
}
|
||||
return fwRules
|
||||
}
|
||||
|
||||
func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
||||
distPeersWithPolicy := make(map[string]struct{})
|
||||
for _, id := range rule.Sources {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
if pID == peerID {
|
||||
continue
|
||||
}
|
||||
_, distPeer := distributionPeers[pID]
|
||||
_, valid := validatedPeersMap[pID]
|
||||
if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
|
||||
distPeersWithPolicy[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
|
||||
for pID := range distPeersWithPolicy {
|
||||
peer := a.Peers[pID]
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
distributionGroupPeers = append(distributionGroupPeers, peer)
|
||||
}
|
||||
return distributionGroupPeers
|
||||
}
|
||||
|
||||
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
|
||||
distPeers := make(map[string]struct{})
|
||||
for _, id := range route.Groups {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
distPeers[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return distPeers
|
||||
}
|
||||
|
||||
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
||||
var rules []*RouteFirewallRule
|
||||
|
||||
sources := []string{"0.0.0.0/0"}
|
||||
if route.Network.Addr().Is6() {
|
||||
sources = []string{"::/0"}
|
||||
}
|
||||
rule := RouteFirewallRule{
|
||||
SourceRanges: sources,
|
||||
Action: string(PolicyTrafficActionAccept),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(PolicyRuleProtocolALL),
|
||||
Domains: route.Domains,
|
||||
IsDynamic: route.IsDynamic(),
|
||||
RouteID: route.ID,
|
||||
}
|
||||
|
||||
rules = append(rules, &rule)
|
||||
|
||||
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
|
||||
if route.IsDynamic() {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
||||
// and returns a list of policies that have rules with destinations matching the specified groups.
|
||||
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
||||
routePolicies := make([]*Policy, 0)
|
||||
for _, groupID := range accessControlGroups {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool {
|
||||
return groupID == group.ID
|
||||
})
|
||||
if exist {
|
||||
routePolicies = append(routePolicies, policy)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routePolicies
|
||||
}
|
||||
|
||||
// GetPeerNetworkResourceFirewallRules gets the network resources firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, routes []*route.Route, resourcePolicies map[string][]*Policy) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0)
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Peer != peer.Key {
|
||||
continue
|
||||
}
|
||||
resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())]
|
||||
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
|
||||
|
||||
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
|
||||
for _, rule := range rules {
|
||||
if len(rule.SourceRanges) > 0 {
|
||||
routesFirewallRules = append(routesFirewallRules, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
// getPoliciesSourcePeers collects all unique peers from the source groups defined in the given policies.
|
||||
func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[string]struct{} {
|
||||
sourcePeers := make(map[string]struct{})
|
||||
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group := groups[sourceGroup]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, peer := range group.Peers {
|
||||
sourcePeers[peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sourcePeers
|
||||
}
|
||||
|
||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
|
||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
peerRoutesMembership := make(LookupMap)
|
||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||
}
|
||||
|
||||
groupListMap := a.GetPeerGroups(peerID)
|
||||
for _, peer := range aclPeers {
|
||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
|
||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||
routes = append(routes, filteredRoutes...)
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
||||
groupList := make(LookupMap)
|
||||
for groupID, group := range a.Groups {
|
||||
for _, id := range group.Peers {
|
||||
if id == peerID {
|
||||
groupList[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return groupList
|
||||
}
|
||||
|
||||
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
|
||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
_, found := peerMemberships[string(r.GetHAUniqueID())]
|
||||
if !found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
|
||||
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
for _, groupID := range r.Groups {
|
||||
_, found := groupListMap[groupID]
|
||||
if found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
// If the given is not a routing peer, then the lists are empty.
|
||||
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
|
||||
|
||||
peer := a.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
// log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
seenRoute := make(map[route.ID]struct{})
|
||||
|
||||
takeRoute := func(r *route.Route, id string) {
|
||||
if _, ok := seenRoute[r.ID]; ok {
|
||||
return
|
||||
}
|
||||
seenRoute[r.ID] = struct{}{}
|
||||
|
||||
if r.Enabled {
|
||||
r.Peer = peer.Key
|
||||
enabledRoutes = append(enabledRoutes, r)
|
||||
return
|
||||
}
|
||||
disabledRoutes = append(disabledRoutes, r)
|
||||
}
|
||||
|
||||
for _, r := range a.Routes {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
for _, id := range group.Peers {
|
||||
if id != peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
newPeerRoute := r.Copy()
|
||||
newPeerRoute.Peer = id
|
||||
newPeerRoute.PeerGroups = nil
|
||||
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
|
||||
takeRoute(newPeerRoute, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
if r.Peer == peerID {
|
||||
takeRoute(r.Copy(), peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
func (a *Account) addNetworksRoutingPeers(
|
||||
networkResourcesRoutes []*route.Route,
|
||||
peer *nbpeer.Peer,
|
||||
peersToConnect []*nbpeer.Peer,
|
||||
expiredPeers []*nbpeer.Peer,
|
||||
isRouter bool,
|
||||
sourcePeers map[string]struct{},
|
||||
) []*nbpeer.Peer {
|
||||
|
||||
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
|
||||
for _, r := range networkResourcesRoutes {
|
||||
networkRoutesPeers[r.PeerID] = struct{}{}
|
||||
}
|
||||
|
||||
delete(sourcePeers, peer.ID)
|
||||
delete(networkRoutesPeers, peer.ID)
|
||||
|
||||
for _, existingPeer := range peersToConnect {
|
||||
delete(sourcePeers, existingPeer.ID)
|
||||
delete(networkRoutesPeers, existingPeer.ID)
|
||||
}
|
||||
for _, expPeer := range expiredPeers {
|
||||
delete(sourcePeers, expPeer.ID)
|
||||
delete(networkRoutesPeers, expPeer.ID)
|
||||
}
|
||||
|
||||
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
|
||||
if isRouter {
|
||||
for p := range sourcePeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
for p := range networkRoutesPeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
for p := range missingPeers {
|
||||
if missingPeer := a.Peers[p]; missingPeer != nil {
|
||||
peersToConnect = append(peersToConnect, missingPeer)
|
||||
}
|
||||
}
|
||||
|
||||
return peersToConnect
|
||||
}
|
||||
|
||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
enabled := true
|
||||
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
|
||||
_, found := peerGroups[groupID]
|
||||
if found {
|
||||
enabled = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
groupList := account.GetPeerGroups(peerID)
|
||||
|
||||
var peerNSGroups []*nbdns.NameServerGroup
|
||||
|
||||
for _, nsGroup := range account.NameServerGroups {
|
||||
if !nsGroup.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, gID := range nsGroup.Groups {
|
||||
_, found := groupList[gID]
|
||||
if found {
|
||||
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerNSGroups
|
||||
}
|
||||
|
||||
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
|
||||
if a.NetworkMapCache != nil {
|
||||
return
|
||||
}
|
||||
a.nmapInitOnce.Do(func() {
|
||||
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeers map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerAddedIncremental(peerId)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerDeleted(peerId)
|
||||
}
|
||||
|
||||
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
|
||||
if a.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
a.NetworkMapCache.UpdatePeer(peer)
|
||||
}
|
||||
|
||||
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
|
||||
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
||||
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
|
||||
peerIPs := make(map[string]struct{})
|
||||
|
||||
// Add peer's own IP to include its own DNS records
|
||||
peerIPs[peer.IP.String()] = struct{}{}
|
||||
|
||||
for _, peerToConnect := range peersToConnect {
|
||||
peerIPs[peerToConnect.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, expiredPeer := range expiredPeers {
|
||||
peerIPs[expiredPeer.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
if _, exists := peerIPs[record.RData]; exists {
|
||||
filteredRecords = append(filteredRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredRecords
|
||||
}
|
||||
1069
management/server/types/networkmap_golden_test.go
Normal file
1069
management/server/types/networkmap_golden_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1761
management/server/types/networkmapbuilder.go
Normal file
1761
management/server/types/networkmapbuilder.go
Normal file
File diff suppressed because it is too large
Load Diff
106
management/server/update_buffer.go
Normal file
106
management/server/update_buffer.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type UpdateBuffer struct {
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
update *UpdateMessage
|
||||
closed bool
|
||||
metrics *telemetry.UpdateChannelMetrics
|
||||
overwriteCount int // Number of overwrites since last Pop
|
||||
lastPopTime time.Time // Time of last Pop
|
||||
}
|
||||
|
||||
func NewUpdateBuffer(metrics *telemetry.UpdateChannelMetrics) *UpdateBuffer {
|
||||
ub := &UpdateBuffer{metrics: metrics}
|
||||
ub.cond = sync.NewCond(&ub.mu)
|
||||
return ub
|
||||
}
|
||||
|
||||
func (b *UpdateBuffer) Push(update *UpdateMessage) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.update == nil || update.Update.NetworkMap.Serial > b.update.Update.NetworkMap.Serial || b.update.Update.NetworkMap.Serial == 0 {
|
||||
if b.update == nil {
|
||||
b.metrics.CountBufferPush()
|
||||
} else {
|
||||
b.metrics.CountBufferOverwrite()
|
||||
b.overwriteCount++
|
||||
}
|
||||
|
||||
b.update = update
|
||||
b.cond.Signal()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
b.metrics.CountBufferIgnore()
|
||||
}
|
||||
|
||||
func (b *UpdateBuffer) Pop(ctx context.Context) (*UpdateMessage, int, time.Duration, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
for b.update == nil && !b.closed {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, 0, fmt.Errorf("context cancelled")
|
||||
default:
|
||||
}
|
||||
|
||||
waitCh := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
b.cond.Broadcast()
|
||||
case <-waitCh:
|
||||
// noop
|
||||
}
|
||||
}()
|
||||
b.cond.Wait()
|
||||
close(waitCh)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, 0, fmt.Errorf("context cancelled")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if b.closed {
|
||||
return nil, 0, 0, fmt.Errorf("buffer closed")
|
||||
}
|
||||
|
||||
msg := b.update
|
||||
overwrites := b.overwriteCount
|
||||
|
||||
// Calculate time since last pop
|
||||
now := time.Now()
|
||||
var timeSinceLastPop time.Duration
|
||||
if !b.lastPopTime.IsZero() {
|
||||
timeSinceLastPop = now.Sub(b.lastPopTime)
|
||||
}
|
||||
|
||||
// Reset counters
|
||||
b.update = nil
|
||||
b.overwriteCount = 0
|
||||
b.lastPopTime = now
|
||||
|
||||
return msg, overwrites, timeSinceLastPop, nil
|
||||
}
|
||||
|
||||
func (b *UpdateBuffer) Close() {
|
||||
b.mu.Lock()
|
||||
b.closed = true
|
||||
b.cond.Broadcast()
|
||||
b.mu.Unlock()
|
||||
}
|
||||
@@ -7,21 +7,39 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const channelBufferSize = 100
|
||||
|
||||
type UpdateChannel struct {
|
||||
Important chan *UpdateMessage
|
||||
NetworkMap *UpdateBuffer
|
||||
}
|
||||
|
||||
func NewUpdateChannel(metrics *telemetry.UpdateChannelMetrics) *UpdateChannel {
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
buffer := NewUpdateBuffer(metrics)
|
||||
|
||||
return &UpdateChannel{
|
||||
Important: channel,
|
||||
NetworkMap: buffer,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateChannel) Close() {
|
||||
close(u.Important)
|
||||
u.NetworkMap.Close()
|
||||
}
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *types.NetworkMap
|
||||
Update *proto.SyncResponse
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
peerChannels map[string]*UpdateChannel
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
@@ -31,14 +49,14 @@ type PeersUpdateManager struct {
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerChannels: make(map[string]*UpdateChannel),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
// SendUpdate sends update message to the peer's channel
|
||||
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
|
||||
// SendImportantUpdate sends update message to the peer that needs to be received
|
||||
func (p *PeersUpdateManager) SendImportantUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
@@ -54,19 +72,42 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
case channel <- update:
|
||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
||||
case channel.Important <- update:
|
||||
// log.WithContext(ctx).Debugf("update was sent to important channel for peer %s", peerID)
|
||||
default:
|
||||
dropped = true
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||
// log.WithContext(ctx).Warnf("important channel for peer %s is %d full or closed", peerID, len(channel.Important))
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("peer %s has no channel", peerID)
|
||||
// log.WithContext(ctx).Debugf("peer %s has no important channel", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
// SendNetworkMapUpdate sends a network map update to the peer's channel
|
||||
func (p *PeersUpdateManager) SendNetworkMapUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
p.channelsMux.RLock()
|
||||
|
||||
defer func() {
|
||||
p.channelsMux.RUnlock()
|
||||
if p.metrics != nil {
|
||||
p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped)
|
||||
}
|
||||
}()
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
channel.NetworkMap.Push(update)
|
||||
// log.WithContext(ctx).Debugf("update was sent to network map buffer for peer %s", peerID)
|
||||
} else {
|
||||
// log.WithContext(ctx).Debugf("peer %s has no network map buffer", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) *UpdateChannel {
|
||||
start := time.Now()
|
||||
|
||||
closed := false
|
||||
@@ -82,21 +123,21 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
closed = true
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
channel.Close()
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
p.peerChannels[peerID] = channel
|
||||
|
||||
newChannel := NewUpdateChannel(p.metrics.UpdateChannelMetrics())
|
||||
p.peerChannels[peerID] = newChannel
|
||||
|
||||
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
|
||||
|
||||
return channel
|
||||
return newChannel
|
||||
}
|
||||
|
||||
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
channel.Close()
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
return
|
||||
@@ -176,3 +217,10 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetChannelCount returns the number of active peer channels
|
||||
func (p *PeersUpdateManager) GetChannelCount() int {
|
||||
p.channelsMux.RLock()
|
||||
defer p.channelsMux.RUnlock()
|
||||
return len(p.peerChannels)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user