mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 01:12:29 -04:00
Compare commits
4 Commits
feature/us
...
feature/st
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df101bf071 | ||
|
|
8393bf1b17 | ||
|
|
02a04958e7 | ||
|
|
000e99e7f3 |
@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
|
||||
return &tls.Config{
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
var certChain [][]byte
|
||||
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
|
||||
if requiresCredSSP {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS12
|
||||
} else {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS13
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -19,18 +21,34 @@ const (
|
||||
RDCleanPathVersion = 3390
|
||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||
RDCleanPathProxyScheme = "ws"
|
||||
|
||||
rdpDialTimeout = 15 * time.Second
|
||||
|
||||
GeneralErrorCode = 1
|
||||
WSAETimedOut = 10060
|
||||
WSAEConnRefused = 10061
|
||||
WSAEConnAborted = 10053
|
||||
WSAEConnReset = 10054
|
||||
WSAEGenericError = 10050
|
||||
)
|
||||
|
||||
type RDCleanPathPDU struct {
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error []byte `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathErr struct {
|
||||
ErrorCode int16 `asn1:"tag:0,explicit"`
|
||||
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
|
||||
WSALastError int16 `asn1:"tag:2,explicit,optional"`
|
||||
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathProxy struct {
|
||||
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
destination := conn.destination
|
||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
_, err = rdpConn.Write(firstPacket)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write first packet: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
n, err := rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func errorToWSACode(err error) int16 {
|
||||
if err == nil {
|
||||
return WSAEGenericError
|
||||
}
|
||||
var netErr *net.OpError
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return WSAEConnAborted
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return WSAEConnReset
|
||||
}
|
||||
return WSAEGenericError
|
||||
}
|
||||
|
||||
func newWSAError(err error) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
WSALastError: errorToWSACode(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPError(statusCode int16) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
HTTPStatusCode: statusCode,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
@@ -11,11 +12,17 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
|
||||
protocolSSL = 0x00000001
|
||||
protocolHybridEx = 0x00000008
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||
|
||||
if pdu.Version != RDCleanPathVersion {
|
||||
p.sendRDCleanPathError(conn, "Unsupported version")
|
||||
p.sendRDCleanPathError(conn, newHTTPError(400))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
||||
destination = pdu.Destination
|
||||
}
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, "Connection failed")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
||||
p.setupTLSConnection(conn, pdu)
|
||||
}
|
||||
|
||||
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
|
||||
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
|
||||
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
|
||||
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
|
||||
const minResponseLength = 19
|
||||
|
||||
if len(x224Response) < minResponseLength {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
// Per X.224 specification:
|
||||
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
|
||||
// x224Response[5] == 0xD0: X.224 Data TPDU code
|
||||
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
if x224Response[11] == 0x02 {
|
||||
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
|
||||
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
|
||||
|
||||
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
|
||||
return hasNLA, flags, true
|
||||
}
|
||||
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
var x224Response []byte
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
x224Response = response[:n]
|
||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn)
|
||||
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
|
||||
if detected {
|
||||
if requiresCredSSP {
|
||||
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
|
||||
} else {
|
||||
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
|
||||
|
||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||
conn.tlsConn = tlsConn
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Errorf("TLS handshake failed: %v", err)
|
||||
p.sendRDCleanPathError(conn, "TLS handshake failed")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
X224ConnectionPDU: response[:n],
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
} else {
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
}
|
||||
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TCP connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
|
||||
pdu := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: []byte(errorMsg),
|
||||
}
|
||||
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||
msgChan := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
129
management/server/store/cache/dual_key_cache.go
vendored
Normal file
129
management/server/store/cache/dual_key_cache.go
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// DualKeyCache provides a caching mechanism where each entry has two keys:
|
||||
// - Primary key (e.g., objectID): used for accessing and invalidating specific entries
|
||||
// - Secondary key (e.g., accountID): used for bulk invalidation of all entries with the same secondary key
|
||||
type DualKeyCache[K1 comparable, K2 comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
primaryIndex map[K1]V // Primary key -> Value
|
||||
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
|
||||
reverseLookup map[K1]K2 // Primary key -> Secondary key
|
||||
}
|
||||
|
||||
// NewDualKeyCache creates a new dual-key cache
|
||||
func NewDualKeyCache[K1 comparable, K2 comparable, V any]() *DualKeyCache[K1, K2, V] {
|
||||
return &DualKeyCache[K1, K2, V]{
|
||||
primaryIndex: make(map[K1]V),
|
||||
secondaryIndex: make(map[K2]map[K1]struct{}),
|
||||
reverseLookup: make(map[K1]K2),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache using the primary key
|
||||
func (c *DualKeyCache[K1, K2, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
value, ok := c.primaryIndex[primaryKey]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with both primary and secondary keys
|
||||
func (c *DualKeyCache[K1, K2, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, value V) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if oldSecondaryKey, exists := c.reverseLookup[primaryKey]; exists {
|
||||
if primaryKeys, ok := c.secondaryIndex[oldSecondaryKey]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, oldSecondaryKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.primaryIndex[primaryKey] = value
|
||||
c.reverseLookup[primaryKey] = secondaryKey
|
||||
|
||||
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
|
||||
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
|
||||
}
|
||||
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
|
||||
}
|
||||
|
||||
// InvalidateByPrimaryKey removes an entry using the primary key
|
||||
func (c *DualKeyCache[K1, K2, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if secondaryKey, exists := c.reverseLookup[primaryKey]; exists {
|
||||
if primaryKeys, ok := c.secondaryIndex[secondaryKey]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, secondaryKey)
|
||||
}
|
||||
}
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
}
|
||||
|
||||
// InvalidateBySecondaryKey removes all entries with the given secondary key
|
||||
func (c *DualKeyCache[K1, K2, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
primaryKeys, exists := c.secondaryIndex[secondaryKey]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for primaryKey := range primaryKeys {
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.secondaryIndex, secondaryKey)
|
||||
}
|
||||
|
||||
// InvalidateAll removes all entries from the cache
|
||||
func (c *DualKeyCache[K1, K2, V]) InvalidateAll(ctx context.Context) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.primaryIndex = make(map[K1]V)
|
||||
c.secondaryIndex = make(map[K2]map[K1]struct{})
|
||||
c.reverseLookup = make(map[K1]K2)
|
||||
}
|
||||
|
||||
// Size returns the number of entries in the cache
|
||||
func (c *DualKeyCache[K1, K2, V]) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.primaryIndex)
|
||||
}
|
||||
|
||||
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
|
||||
// The loadFunc should return both the value and the secondary key (extracted from the value)
|
||||
func (c *DualKeyCache[K1, K2, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, error)) (V, error) {
|
||||
if value, ok := c.Get(ctx, primaryKey); ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
value, secondaryKey, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, primaryKey, secondaryKey, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
77
management/server/store/cache/single_key_cache.go
vendored
Normal file
77
management/server/store/cache/single_key_cache.go
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SingleKeyCache provides a simple caching mechanism with a single key
|
||||
type SingleKeyCache[K comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
cache map[K]V // Key -> Value
|
||||
}
|
||||
|
||||
// NewSingleKeyCache creates a new single-key cache
|
||||
func NewSingleKeyCache[K comparable, V any]() *SingleKeyCache[K, V] {
|
||||
return &SingleKeyCache[K, V]{
|
||||
cache: make(map[K]V),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache using the key
|
||||
func (c *SingleKeyCache[K, V]) Get(ctx context.Context, key K) (V, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
value, ok := c.cache[key]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the given key
|
||||
func (c *SingleKeyCache[K, V]) Set(ctx context.Context, key K, value V) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cache[key] = value
|
||||
}
|
||||
|
||||
// Invalidate removes an entry using the key
|
||||
func (c *SingleKeyCache[K, V]) Invalidate(ctx context.Context, key K) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
delete(c.cache, key)
|
||||
}
|
||||
|
||||
// InvalidateAll removes all entries from the cache
|
||||
func (c *SingleKeyCache[K, V]) InvalidateAll(ctx context.Context) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cache = make(map[K]V)
|
||||
}
|
||||
|
||||
// Size returns the number of entries in the cache
|
||||
func (c *SingleKeyCache[K, V]) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.cache)
|
||||
}
|
||||
|
||||
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
|
||||
func (c *SingleKeyCache[K, V]) GetOrSet(ctx context.Context, key K, loadFunc func() (V, error)) (V, error) {
|
||||
if value, ok := c.Get(ctx, key); ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
value, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, key, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
242
management/server/store/cache/triple_key_cache.go
vendored
Normal file
242
management/server/store/cache/triple_key_cache.go
vendored
Normal file
@@ -0,0 +1,242 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TripleKeyCache provides a caching mechanism where each entry has three keys:
|
||||
// - Primary key (K1): used for accessing and invalidating specific entries
|
||||
// - Secondary key (K2): used for bulk invalidation of all entries with the same secondary key
|
||||
// - Tertiary key (K3): used for bulk invalidation of all entries with the same tertiary key
|
||||
type TripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
primaryIndex map[K1]V // Primary key -> Value
|
||||
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
|
||||
tertiaryIndex map[K3]map[K1]struct{} // Tertiary key -> Set of primary keys
|
||||
reverseLookup map[K1]keyPair[K2, K3] // Primary key -> Secondary and Tertiary keys
|
||||
}
|
||||
|
||||
type keyPair[K2 comparable, K3 comparable] struct {
|
||||
secondary K2
|
||||
tertiary K3
|
||||
}
|
||||
|
||||
// NewTripleKeyCache creates a new triple-key cache
|
||||
func NewTripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any]() *TripleKeyCache[K1, K2, K3, V] {
|
||||
return &TripleKeyCache[K1, K2, K3, V]{
|
||||
primaryIndex: make(map[K1]V),
|
||||
secondaryIndex: make(map[K2]map[K1]struct{}),
|
||||
tertiaryIndex: make(map[K3]map[K1]struct{}),
|
||||
reverseLookup: make(map[K1]keyPair[K2, K3]),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache using the primary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
value, ok := c.primaryIndex[primaryKey]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with primary, secondary, and tertiary keys
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, tertiaryKey K3, value V) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if oldKeys, exists := c.reverseLookup[primaryKey]; exists {
|
||||
if primaryKeys, ok := c.secondaryIndex[oldKeys.secondary]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, oldKeys.secondary)
|
||||
}
|
||||
}
|
||||
if primaryKeys, ok := c.tertiaryIndex[oldKeys.tertiary]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.tertiaryIndex, oldKeys.tertiary)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.primaryIndex[primaryKey] = value
|
||||
c.reverseLookup[primaryKey] = keyPair[K2, K3]{
|
||||
secondary: secondaryKey,
|
||||
tertiary: tertiaryKey,
|
||||
}
|
||||
|
||||
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
|
||||
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
|
||||
}
|
||||
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
|
||||
|
||||
if _, exists := c.tertiaryIndex[tertiaryKey]; !exists {
|
||||
c.tertiaryIndex[tertiaryKey] = make(map[K1]struct{})
|
||||
}
|
||||
c.tertiaryIndex[tertiaryKey][primaryKey] = struct{}{}
|
||||
}
|
||||
|
||||
// InvalidateByPrimaryKey removes an entry using the primary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if keys, exists := c.reverseLookup[primaryKey]; exists {
|
||||
if primaryKeys, ok := c.secondaryIndex[keys.secondary]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, keys.secondary)
|
||||
}
|
||||
}
|
||||
if primaryKeys, ok := c.tertiaryIndex[keys.tertiary]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.tertiaryIndex, keys.tertiary)
|
||||
}
|
||||
}
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
}
|
||||
|
||||
// InvalidateBySecondaryKey removes all entries with the given secondary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
primaryKeys, exists := c.secondaryIndex[secondaryKey]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for primaryKey := range primaryKeys {
|
||||
if keys, ok := c.reverseLookup[primaryKey]; ok {
|
||||
if tertiaryPrimaryKeys, exists := c.tertiaryIndex[keys.tertiary]; exists {
|
||||
delete(tertiaryPrimaryKeys, primaryKey)
|
||||
if len(tertiaryPrimaryKeys) == 0 {
|
||||
delete(c.tertiaryIndex, keys.tertiary)
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.secondaryIndex, secondaryKey)
|
||||
}
|
||||
|
||||
// InvalidateByTertiaryKey removes all entries with the given tertiary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByTertiaryKey(ctx context.Context, tertiaryKey K3) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
primaryKeys, exists := c.tertiaryIndex[tertiaryKey]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for primaryKey := range primaryKeys {
|
||||
if keys, ok := c.reverseLookup[primaryKey]; ok {
|
||||
if secondaryPrimaryKeys, exists := c.secondaryIndex[keys.secondary]; exists {
|
||||
delete(secondaryPrimaryKeys, primaryKey)
|
||||
if len(secondaryPrimaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, keys.secondary)
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.tertiaryIndex, tertiaryKey)
|
||||
}
|
||||
|
||||
// InvalidateAll removes all entries from the cache
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateAll(ctx context.Context) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.primaryIndex = make(map[K1]V)
|
||||
c.secondaryIndex = make(map[K2]map[K1]struct{})
|
||||
c.tertiaryIndex = make(map[K3]map[K1]struct{})
|
||||
c.reverseLookup = make(map[K1]keyPair[K2, K3])
|
||||
}
|
||||
|
||||
// Size returns the number of entries in the cache
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.primaryIndex)
|
||||
}
|
||||
|
||||
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
|
||||
// The loadFunc should return the value, secondary key, and tertiary key (extracted from the value)
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, K3, error)) (V, error) {
|
||||
if value, ok := c.Get(ctx, primaryKey); ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
value, secondaryKey, tertiaryKey, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// GetOrSetBySecondaryKey retrieves a value from the cache using the secondary key, or sets it using the provided function if not found
|
||||
// The loadFunc should return the value, primary key, secondary key, and tertiary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetBySecondaryKey(ctx context.Context, secondaryKey K2, loadFunc func() (V, K1, K3, error)) (V, error) {
|
||||
c.mu.RLock()
|
||||
if primaryKeys, exists := c.secondaryIndex[secondaryKey]; exists && len(primaryKeys) > 0 {
|
||||
for primaryKey := range primaryKeys {
|
||||
if value, ok := c.primaryIndex[primaryKey]; ok {
|
||||
c.mu.RUnlock()
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
value, primaryKey, tertiaryKey, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// GetOrSetByTertiaryKey retrieves a value from the cache using the tertiary key, or sets it using the provided function if not found
|
||||
// The loadFunc should return the value, primary key, secondary key, and tertiary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetByTertiaryKey(ctx context.Context, tertiaryKey K3, loadFunc func() (V, K1, K2, error)) (V, error) {
|
||||
c.mu.RLock()
|
||||
if primaryKeys, exists := c.tertiaryIndex[tertiaryKey]; exists && len(primaryKeys) > 0 {
|
||||
for primaryKey := range primaryKeys {
|
||||
if value, ok := c.primaryIndex[primaryKey]; ok {
|
||||
c.mu.RUnlock()
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
value, primaryKey, secondaryKey, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
@@ -14,6 +15,8 @@ type StoreMetrics struct {
|
||||
persistenceDurationMicro metric.Int64Histogram
|
||||
persistenceDurationMs metric.Int64Histogram
|
||||
transactionDurationMs metric.Int64Histogram
|
||||
queryDurationMs metric.Int64Histogram
|
||||
queryCounter metric.Int64Counter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -59,12 +62,29 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryDurationMs, err := meter.Int64Histogram("management.store.query.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of database query operations with operation type and table name"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryCounter, err := meter.Int64Counter("management.store.query.count",
|
||||
metric.WithDescription("Count of database query operations with operation type, table name, and status"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &StoreMetrics{
|
||||
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
|
||||
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
|
||||
persistenceDurationMicro: persistenceDurationMicro,
|
||||
persistenceDurationMs: persistenceDurationMs,
|
||||
transactionDurationMs: transactionDurationMs,
|
||||
queryDurationMs: queryDurationMs,
|
||||
queryCounter: queryCounter,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
@@ -85,3 +105,13 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
|
||||
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
|
||||
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
// CountStoreOperation records a store operation with its method name, status, and duration
|
||||
func (metrics *StoreMetrics) CountStoreOperation(method string, duration time.Duration) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("method", method),
|
||||
}
|
||||
|
||||
metrics.queryDurationMs.Record(metrics.ctx, duration.Milliseconds(), metric.WithAttributes(attrs...))
|
||||
metrics.queryCounter.Add(metrics.ctx, 1, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user