Compare commits

...

5 Commits

Author SHA1 Message Date
Zoltán Papp
b87173f47d Move header offset calculation to private values 2024-09-10 17:06:22 +02:00
Zoltán Papp
56badd7535 Fix message type 2024-09-10 16:34:11 +02:00
Zoltán Papp
1d233a5b5e Decrease the max size of the handshake message 2024-09-10 16:31:36 +02:00
Zoltán Papp
d3e7e6bc9c Rename constants 2024-09-10 16:23:34 +02:00
Zoltán Papp
a701148658 Eliminate gob usage from Relay protocol 2024-09-10 16:08:51 +02:00
7 changed files with 208 additions and 103 deletions

View File

@@ -14,8 +14,6 @@ import (
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/messages/address"
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
)
const (
@@ -240,31 +238,21 @@ func (c *Client) connect() error {
}
func (c *Client) handShake() error {
authMsg := &auth2.Msg{
AuthAlgorithm: auth2.AlgoHMACSHA256,
AdditionalData: c.authTokenStore.TokenBinary(),
}
authData, err := authMsg.Marshal()
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
return fmt.Errorf("marshal auth message: %w", err)
}
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
log.Errorf("failed to marshal auth message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to send hello message: %s", err)
log.Errorf("failed to send auth message: %s", err)
return err
}
buf := make([]byte, messages.MaxHandshakeSize)
buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf)
if err != nil {
log.Errorf("failed to read hello response: %s", err)
log.Errorf("failed to read auth response: %s", err)
return err
}
@@ -273,29 +261,24 @@ func (c *Client) handShake() error {
return fmt.Errorf("validate version: %w", err)
}
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil {
log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeHelloResponse {
if msgType != messages.MsgTypeAuthResponse {
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil {
return err
}
addr, err := address.Unmarshal(additionalData)
if err != nil {
return fmt.Errorf("unmarshal address: %w", err)
}
c.muInstanceURL.Lock()
c.instanceURL = &RelayAddr{addr: addr.URL}
c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock()
return nil
}
@@ -329,14 +312,14 @@ func (c *Client) readLoop(relayConn net.Conn) {
continue
}
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
c.bufPool.Put(bufPtr)
continue
}
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
if !c.handleMsg(msgType, buf[:n], bufPtr, hc, internallyStoppedFlag) {
break
}
}

View File

@@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package address
import (
@@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) {
}
return buf.Bytes(), nil
}
func Unmarshal(data []byte) (*Address, error) {
var addr Address
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&addr); err != nil {
return nil, fmt.Errorf("decode Address: %w", err)
}
return &addr, nil
}

View File

@@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package auth
import (
@@ -30,15 +31,6 @@ type Msg struct {
AdditionalData []byte
}
func (msg *Msg) Marshal() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(msg); err != nil {
return nil, fmt.Errorf("encode Msg: %w", err)
}
return buf.Bytes(), nil
}
func UnmarshalMsg(data []byte) (*Msg, error) {
var msg *Msg

View File

@@ -7,27 +7,36 @@ import (
)
const (
MsgTypeUnknown MsgType = 0
MsgTypeHello MsgType = 1
MaxHandshakeSize = 212
MaxHandshakeRespSize = 8192
CurrentProtocolVersion = 1
MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1
// Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2
MsgTypeTransport MsgType = 3
MsgTypeClose MsgType = 4
MsgTypeHealthCheck MsgType = 5
MsgTypeAuth = 6
MsgTypeAuthResponse = 7
SizeOfVersionByte = 1
SizeOfMsgType = 1
sizeOfVersionByte = 1
sizeOfMsgType = 1
SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
sizeOfCommonHeader = sizeOfVersionByte + sizeOfMsgType
sizeOfMagicByte = 4
headerSizeTransport = IDSize
headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0
headerSizeTransport = sizeOfCommonHeader + IDSize
MaxHandshakeSize = 8192
headerSizeHello = sizeOfCommonHeader + sizeOfMagicByte + IDSize
headerSizeHelloResp = sizeOfCommonHeader + sizeOfCommonHeader
CurrentProtocolVersion = 1
headerSizeAuth = sizeOfCommonHeader + sizeOfMagicByte + IDSize
headerSizeAuthResp = sizeOfCommonHeader
)
var (
@@ -47,6 +56,10 @@ func (m MsgType) String() string {
return "hello"
case MsgTypeHelloResponse:
return "hello response"
case MsgTypeAuth:
return "auth"
case MsgTypeAuthResponse:
return "auth response"
case MsgTypeTransport:
return "transport"
case MsgTypeClose:
@@ -58,13 +71,9 @@ func (m MsgType) String() string {
}
}
type HelloResponse struct {
InstanceAddress string
}
// ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte {
if len(msg) < sizeOfCommonHeader {
return 0, ErrInvalidMessageLength
}
version := int(msg[0])
@@ -76,14 +85,15 @@ func ValidateVersion(msg []byte) (int, error) {
// DetermineClientMessageType determines the message type from the first the message
func DetermineClientMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
if len(msg) < sizeOfCommonHeader {
return 0, ErrInvalidMessageLength
}
msgType := MsgType(msg[0])
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHello,
MsgTypeAuth,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
@@ -95,14 +105,15 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
// DetermineServerMessageType determines the message type from the first the message
func DetermineServerMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
if len(msg) < sizeOfCommonHeader {
return 0, ErrInvalidMessageLength
}
msgType := MsgType(msg[0])
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHelloResponse,
MsgTypeAuthResponse,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
@@ -112,6 +123,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
}
}
// Deprecated: Use MarshalAuthMsg instead.
// MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
@@ -122,12 +134,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
msg := make([]byte, sizeOfCommonHeader+sizeOfMagicByte, sizeOfCommonHeader+headerSizeHello+len(additions))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHello)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
copy(msg[sizeOfCommonHeader:sizeOfCommonHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, additions...)
@@ -135,25 +147,27 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return msg, nil
}
// Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
if !bytes.Equal(msg[sizeOfCommonHeader:sizeOfCommonHeader+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
return msg[sizeOfCommonHeader+sizeOfMagicByte : headerSizeHello], msg[headerSizeHello:], nil
}
// Deprecated: Use MarshalAuthResponse instead.
// MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
msg := make([]byte, headerSizeHelloResp, headerSizeHelloResp+len(additionalData))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHelloResponse)
@@ -163,6 +177,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
return msg, nil
}
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp {
@@ -171,11 +186,70 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
return msg, nil
}
// MarshalAuthMsg initial authentication message
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, sizeOfCommonHeader+sizeOfMagicByte, sizeOfCommonHeader+headerSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth)
copy(msg[sizeOfCommonHeader:sizeOfCommonHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, authPayload...)
return msg, nil
}
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeAuth {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
}
// MarshalAuthResponse creates a response message to the auth.
// In case of success connection the server response with a AuthResponse message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, sizeOfCommonHeader, headerSizeAuthResp+len(ab))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuthResponse)
msg = append(msg, ab...)
return msg, nil
}
// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeAuthResp+1 { // +1 is the minimum expected size of the address
return "", ErrInvalidMessageLength
}
return string(msg), nil
}
// MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection.
func MarshalCloseMsg() []byte {
msg := make([]byte, SizeOfProtoHeader)
msg := make([]byte, sizeOfCommonHeader)
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeClose)
@@ -191,12 +265,12 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))
msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport)
copy(msg[SizeOfProtoHeader:], peerID)
copy(msg[sizeOfCommonHeader:], peerID)
msg = append(msg, payload...)
@@ -209,7 +283,7 @@ func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
return nil, nil, ErrInvalidMessageLength
}
return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
return buf[sizeOfCommonHeader:headerSizeTransport], buf[headerSizeTransport:], nil
}
// UnmarshalTransportID extracts the peerID from the transport message.
@@ -217,7 +291,7 @@ func UnmarshalTransportID(buf []byte) ([]byte, error) {
if len(buf) < headerSizeTransport {
return nil, ErrInvalidMessageLength
}
return buf[:headerSizeTransport], nil
return buf[sizeOfCommonHeader:headerSizeTransport], nil
}
// UpdateTransportMsg updates the peerID in the transport message.

View File

@@ -11,13 +11,37 @@ func TestMarshalHelloMsg(t *testing.T) {
t.Fatalf("error: %v", err)
}
receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:])
receivedPeerID, addition, err := UnmarshalHelloMsg(bHello)
if err != nil {
t.Fatalf("error: %v", err)
}
if string(receivedPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
if len(addition) != 0 {
t.Errorf("expected empty addition, got %v", addition)
}
}
func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalAuthMsg(peerID, nil)
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, addition, err := UnmarshalAuthMsg(bHello)
if err != nil {
t.Fatalf("error: %v", err)
}
if string(receivedPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
if len(addition) != 0 {
t.Errorf("expected empty addition, got %v", addition)
}
}
func TestMarshalTransportMsg(t *testing.T) {
@@ -28,7 +52,15 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("error: %v", err)
}
id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:])
tid, err := UnmarshalTransportID(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if string(tid) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, tid)
}
id, respPayload, err := UnmarshalTransportMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}

View File

@@ -68,21 +68,19 @@ func (p *Peer) Work() {
return
}
msg := buf[:n]
_, err = messages.ValidateVersion(msg)
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
p.log.Warnf("failed to validate protocol version: %s", err)
return
}
msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
msgType, err := messages.DetermineClientMessageType(buf[:n])
if err != nil {
p.log.Errorf("failed to determine message type: %s", err)
return
}
p.handleMsgType(ctx, msgType, hc, n, msg)
p.handleMsgType(ctx, msgType, hc, n, buf[:n])
}
}
@@ -175,7 +173,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
}
func (p *Peer) handleTransportMsg(msg []byte) {
peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
peerID, err := messages.UnmarshalTransportID(msg)
if err != nil {
p.log.Errorf("failed to unmarshal transport message: %s", err)
return
@@ -188,7 +186,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return
}
err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
err = messages.UpdateTransportMsg(msg, p.idB)
if err != nil {
p.log.Errorf("failed to update transport message: %s", err)
return

View File

@@ -21,9 +21,10 @@ import (
// Relay represents the relay server
type Relay struct {
metrics *metrics.Metrics
metricsCancel context.CancelFunc
validator auth.Validator
metrics *metrics.Metrics
metricsCancel context.CancelFunc
validator auth.Validator
validatorDummy auth.Validator // todo: this is just a dummy variable. Replace it with the proper validator
store *Store
instanceURL string
@@ -163,19 +164,41 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
}
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineClientMessageType(buf[:n])
if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
}
if msgType != messages.MsgTypeHello {
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
var (
responseMsg []byte
peerID []byte
)
switch msgType {
case messages.MsgTypeHello:
responseMsg, err = r.handleHelloMsg(buf[:n], conn.RemoteAddr())
case messages.MsgTypeAuth:
responseMsg, err = r.handleAuthMsg(buf[:n], conn.RemoteAddr())
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
}
if err != nil {
return nil, err
}
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
_, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
}
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, error) {
peerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, fmt.Errorf("unmarshal hello message: %w", err)
}
log.Warnf("peer is using depracated initial message type: %s (%s)", peerID, remoteAddr)
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
@@ -183,24 +206,36 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
}
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
return nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
}
msg, err := messages.MarshalHelloResponse(addrData)
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
}
_, err = conn.Write(msg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
return responseMsg, nil
}
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, error) {
peerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, fmt.Errorf("unmarshal hello message: %w", err)
}
// todo use the proper validator
if err := r.validatorDummy.Validate(sha256.New, authPayload); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
}
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
}
return responseMsg, nil
}